# Colab-efficientnet-lightning

GridMixup: [IlyaDobrynin/GridMixup](https://github.com/IlyaDobrynin/GridMixup)

Diffaug: [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans)

AdamP: [clovaai/AdamP](https://github.com/clovaai/AdamP)

EfficientNet repo: [lukemelas/EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch)

Original repo: [bentrevett/pytorch-image-classification](https://github.com/bentrevett/pytorch-image-classification)

My fork: [styler00dollar/Colab-image-classification](https://github.com/styler00dollar/Colab-image-classification)

Currently trains without validation and reports training metrics only. Saves the model as a PyTorch ``pth`` file. Uses the GridMixup loss by default.

In [None]:
!nvidia-smi

In [None]:
#@title Install
!pip install efficientnet_pytorch
!pip install adamp
!pip install pytorch_lightning
!pip install tensorboardX

In [None]:
#@title print means and stds
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
train_dir = '/content/data/' #@param
train_data = datasets.ImageFolder(root = train_dir, 
                                  transform = transforms.ToTensor())

# Accumulate per-channel sums over *all pixels* of the dataset (in float64 for
# precision). This yields the true dataset-wide mean/std with every pixel
# weighted equally. Averaging per-image stds (the previous approach) ignores the
# variation *between* images and so underestimates the spread; averaging
# per-image means also mis-weights datasets whose images differ in size.
pixel_sum = torch.zeros(3, dtype=torch.float64)
pixel_sq_sum = torch.zeros(3, dtype=torch.float64)
pixel_count = 0

for img, label in tqdm(train_data):
    img = img.to(torch.float64)
    pixel_sum += img.sum(dim=(1, 2))
    pixel_sq_sum += img.pow(2).sum(dim=(1, 2))
    pixel_count += img.shape[1] * img.shape[2]

mean64 = pixel_sum / pixel_count
means = mean64.float()
# Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
stds = (pixel_sq_sum / pixel_count - mean64 ** 2).clamp(min=0).sqrt().float()

print("\n")
print(f'Calculated means: {means}')
print(f'Calculated stds: {stds}')

In [None]:
#@title init.py
import torch.nn.init as init

def weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    """Re-initialize Conv, BatchNorm2d and Linear layers of ``net`` in place.

    Parameters:
        net (nn.Module): network to initialize; every submodule is visited via
            ``net.apply``.
        init_type (str): initialization for conv weights:
            'normal' | 'xavier' | 'kaiming' | 'orthogonal'.
        init_gain (float): std for 'normal', gain for 'xavier' and 'orthogonal';
            ignored by 'kaiming'.

    Raises:
        NotImplementedError: if ``init_type`` is unknown (raised when the first
            conv layer is visited).

    Layers whose weight or bias is ``None`` (e.g. ``BatchNorm2d(affine=False)``
    or ``Linear(bias=False)``) are skipped instead of crashing.
    """

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        if hasattr(m, 'weight') and 'Conv' in classname:
            if weight is None:
                return
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif 'BatchNorm2d' in classname:
            # Affine-free BatchNorm has no learnable weight/bias to initialize.
            if weight is not None:
                init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'Linear' in classname:
            if weight is not None:
                init.normal_(weight, 0, 0.01)
            if bias is not None:
                init.constant_(bias, 0)

    # Apply the initialization function <init_func> to every submodule
    print('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)

In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchvision.datasets as datasets

class DataModule(pl.LightningDataModule):
    """LightningDataModule serving ``ImageFolder`` datasets for train/val/test.

    :param training_path: root folder with one sub-folder per class
    :param validation_path: validation root; ``None`` (default) reuses ``training_path``
    :param test_path: test root; ``None`` (default) reuses ``training_path``
    :param batch_size: batch size for all three loaders
    :param num_workers: DataLoader worker processes
    :param size: passed to ``transforms.Resize`` and to the crop transforms
    """
    def __init__(self, training_path: str = './', validation_path: "str | None" = None, test_path: "str | None" = None, batch_size: int = 5, num_workers: int = 2, size = 256):
        super().__init__()
        self.training_dir = training_path
        # The validation/test paths used to be stored but ignored (setup() read the
        # training folder for every split). Falling back to the training folder when
        # no path is given keeps that behaviour for callers that don't pass one.
        self.validation_dir = validation_path if validation_path is not None else training_path
        self.test_dir = test_path if test_path is not None else training_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = size

        self.means = [0.6013, 0.5443, 0.5521] #@param {type:"raw"}
        self.std = [0.2496, 0.2509, 0.2433] #@param {type:"raw"}

    def setup(self, stage=None):
        normalize = transforms.Normalize(mean = self.means, std = self.std)

        # Training: random rotation / crop / flip augmentation.
        train_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.RandomRotation(5),
            transforms.RandomCrop(self.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # Validation/test: deterministic pipeline so metrics are reproducible and
        # are not measured on randomly augmented images.
        eval_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.ToTensor(),
            normalize,
        ])

        self.DS_train = datasets.ImageFolder(root = self.training_dir,
                                  transform = train_tf)
        self.DS_validation = datasets.ImageFolder(root = self.validation_dir,
                                  transform = eval_tf)
        self.DS_test = datasets.ImageFolder(root = self.test_dir,
                                  transform = eval_tf)

    def train_dataloader(self):
        return DataLoader(self.DS_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.DS_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DS_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
#@title GridMixupLoss.py
"""
gridmix_pytorch.py (27-3-20)
https://github.com/IlyaDobrynin/GridMixup/blob/main/gridmix/gridmix_pytorch.py
"""
import typing as t
import random
import numpy as np
import math
import torch
from torch import nn


class GridMixupLoss(nn.Module):
    """ Implementation of GridMixup loss

    Mixes every image of a batch with another (randomly permuted) image of the
    same batch through a grid of rectangular holes, and computes the matching
    lambda-weighted cross-entropy.

    :param alpha: Percent of the first image on the crop. Number or Tuple[float, float]
                    - if a number: lambda is drawn from np.random.beta(alpha, alpha)
                    - if Tuple[float, float]: lambda is drawn from
                        np.random.uniform(alpha[0], alpha[1])
    :param n_holes_x: Number of holes along the x axis (int or inclusive (min, max))
    :param hole_aspect_ratio: hole aspect ratio (height / width), number or (min, max)
    :param crop_area_ratio: fraction of the image area covered by the grid region,
                    number or (min, max)
    :param crop_aspect_ratio: aspect ratio (height / width) of the grid region,
                    number or (min, max)

    Scalar arguments may be ints or floats (e.g. ``hole_aspect_ratio=1``).
    """
    def __init__(
            self,
            alpha: t.Union[float, t.Tuple[float, float]] = (0.1, 0.9),
            n_holes_x: t.Union[int, t.Tuple[int, int]] = 20,
            hole_aspect_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
            crop_area_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
            crop_aspect_ratio: t.Union[float, t.Tuple[float, float]] = 1.,
    ):

        super().__init__()
        self.alpha = alpha
        # Normalise every scalar range argument to a (low, high) pair. The old
        # isinstance(..., float) checks left ints such as hole_aspect_ratio=1
        # unconverted, which then failed when indexed in get_sample().
        self.n_holes_x = self._to_range(n_holes_x)
        self.hole_aspect_ratio = self._to_range(hole_aspect_ratio)
        self.crop_area_ratio = self._to_range(crop_area_ratio)
        self.crop_aspect_ratio = self._to_range(crop_aspect_ratio)

        self.loss = nn.CrossEntropyLoss()

    def __str__(self):
        return "gridmixup"

    @staticmethod
    def _to_range(value):
        """Return ``(value, value)`` for a single number, otherwise ``value`` unchanged."""
        if isinstance(value, (int, float)):
            return (value, value)
        return value

    @staticmethod
    def _get_random_crop(height: int, width: int, crop_area_ratio: float, crop_aspect_ratio: float) -> t.Tuple:
        """Pick a random box with (approximately) the requested area and aspect ratio.

        :return: (x1, y1, x2, y2) with 0 <= x1 <= x2 <= width and 0 <= y1 <= y2 <= height
        """
        crop_area = int(height * width * crop_area_ratio)
        crop_width = int(np.sqrt(crop_area / crop_aspect_ratio))
        crop_height = int(crop_width * crop_aspect_ratio)
        # Some area/aspect combinations do not fit (area 1.0 with aspect 0.5 asks
        # for a box ~1.41x the image width). Without clamping, x1/y1 become
        # negative and the mask slice in _get_gridmask wraps around the image.
        crop_width = min(crop_width, width)
        crop_height = min(crop_height, height)

        cx = np.random.random()
        cy = np.random.random()

        y1 = int((height - crop_height) * cy)
        y2 = y1 + crop_height
        x1 = int((width - crop_width) * cx)
        x2 = x1 + crop_width
        return x1, y1, x2, y2

    def _get_gridmask(
            self,
            image_shape: t.Tuple[int, int],
            crop_area_ratio: float,
            crop_aspect_ratio: float,
            lam: float,
            nx: int,
            ar: float,
    ) -> np.ndarray:
        """ Method make grid mask

        :param image_shape: Shape of the images
        :param lam: Lambda parameter
        :param crop_area_ratio: Ratio of the crop area
        :param crop_aspect_ratio: Aspect ratio of the crop
        :param nx: Amount of holes by width
        :param ar: Aspect ratio of the hole
        :return: Binary mask, where holes == 1, background == 0
        :raises ValueError: if ``nx`` is not in [1, crop_width // 2]
        """
        img_height, img_width = image_shape

        # Get coordinates of random box
        xc1, yc1, xc2, yc2 = self._get_random_crop(
            height=img_height,
            width=img_width,
            crop_area_ratio=crop_area_ratio,
            crop_aspect_ratio=crop_aspect_ratio
        )
        height = yc2 - yc1
        width = xc2 - xc1

        if not 1 <= nx <= width // 2:
            raise ValueError(
                f"The nx must be between 1 and {width // 2}.\n"
                f"Give: {nx}"
            )

        # Get patch width, height and ny (patch_height >= 1 avoids division by zero
        # for very small hole aspect ratios)
        patch_width = math.ceil(width / nx)
        patch_height = max(int(patch_width * ar), 1)
        ny = math.ceil(height / patch_height)

        # Calculate ratio of the hole - percent of hole pixels in the patch
        ratio = np.sqrt(1 - lam)

        # Get hole size
        hole_width = int(patch_width * ratio)
        hole_height = int(patch_height * ratio)

        # min 1 pixel and max patch length - 1
        hole_width = min(max(hole_width, 1), patch_width - 1)
        hole_height = min(max(hole_height, 1), patch_height - 1)

        # Make grid mask
        holes = []
        for i in range(nx + 1):
            for j in range(ny + 1):
                x1 = min(patch_width * i, width)
                y1 = min(patch_height * j, height)
                x2 = min(x1 + hole_width, width)
                y2 = min(y1 + hole_height, height)
                holes.append((x1, y1, x2, y2))

        mask = np.zeros(shape=image_shape, dtype=np.uint8)
        for x1, y1, x2, y2 in holes:
            mask[yc1+y1:yc1+y2, xc1+x1:xc1+x2] = 1
        return mask

    def get_sample(self, images: torch.Tensor, targets: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """ Method returns augmented images and targets

        :param images: Batch of non-augmented images, shape (B, C, H, W)
        :param targets: Batch of labels, shape (B, 1)
        :return: (mixed images, targets of shape (3, B) in float64), where the rows
                 are the original labels, the shuffled labels, and lambda (the
                 fraction of pixels kept from the original image) per sample
        """
        # Get new indices
        indices = torch.randperm(images.size(0)).to(images.device)

        # Shuffle labels (index on the targets' own device)
        shuffled_targets = targets[indices.to(targets.device)]

        # Get image shape
        height, width = images.shape[2:]

        # Get lambda
        if isinstance(self.alpha, (int, float)):
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = np.random.uniform(self.alpha[0], self.alpha[1])

        nx = random.randint(self.n_holes_x[0], self.n_holes_x[1])
        ar = np.random.uniform(self.hole_aspect_ratio[0], self.hole_aspect_ratio[1])
        crop_area_ratio = np.random.uniform(self.crop_area_ratio[0], self.crop_area_ratio[1])
        crop_aspect_ratio = np.random.uniform(self.crop_aspect_ratio[0], self.crop_aspect_ratio[1])
        mask = self._get_gridmask(
            image_shape=(height, width),
            crop_area_ratio=crop_area_ratio,
            crop_aspect_ratio=crop_aspect_ratio,
            lam=lam,
            nx=nx,
            ar=ar
        )
        # Adjust lambda to exactly match pixel ratio
        lam = 1 - (mask.sum() / (height * width))

        # Make shuffled images; the mask must live on the *images'* device/dtype
        mask = torch.from_numpy(mask).to(device=images.device, dtype=images.dtype)
        images = images * (1 - mask) + images[indices, ...] * mask

        # Prepare out labels: rows = [targets, shuffled targets, lambda]
        lam_list = torch.full(targets.shape, float(lam), dtype=torch.float64, device=targets.device)
        out_targets = torch.cat(
            [targets.double(), shuffled_targets.double(), lam_list], dim=1
        ).transpose(0, 1)
        return images, out_targets

    def forward(self, preds: torch.Tensor, trues: torch.Tensor) -> torch.Tensor:
        """Lambda-weighted cross-entropy against both label rows of ``trues``."""
        lam = trues[-1][0].float()
        trues1, trues2 = trues[0].long(), trues[1].long()
        loss = self.loss(preds, trues1) * lam + self.loss(preds, trues2) * (1 - lam)
        return loss


In [None]:
#@title diffaug.py
"""
DiffAugment_pytorch.py (27-3-20)
https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py
"""
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    """Apply the comma-separated augmentation ``policy`` to the batch ``x``.

    Each policy name is looked up in ``AUGMENT_FNS`` and its functions are applied
    in order. With ``channels_first=False`` the input is NHWC and is temporarily
    permuted to NCHW. An empty policy returns ``x`` itself, untouched.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW

    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) per sample to all of its values."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Randomly shift each sample along dims 2 and 3, filling vacated pixels with zeros.

    Each sample gets independent integer offsets in [-round(size * ratio), +round(size * ratio)]
    along dims 2 and 3. Expects a 4-D channels-first batch (DiffAugment permutes
    to that layout before calling).
    """
    # Maximum shift in pixels along dim 2 and dim 3 (rounded to nearest).
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # One random offset per sample, shaped (N, 1, 1) to broadcast over the pixel grid.
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # Index grids of shape (N, size(2), size(3)) for batch, dim-2 and dim-3 positions.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # +1 accounts for the 1-pixel padding below; indices that fall outside the
    # image are clamped onto the zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # Zero-pad one pixel on both sides of dims 2 and 3.
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Move channels last, gather every output pixel from its shifted source, then
    # move channels back to dim 1.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample (all channels).

    The rectangle is ``ratio`` of the size along dims 2 and 3 (rounded) and is
    placed at a random centre; parts falling outside the image are clipped.
    Expects a 4-D channels-first batch.
    """
    # Rectangle size along dim 2 and dim 3.
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Random centre per sample, shaped (N, 1, 1) to broadcast over the rectangle grid.
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # Index grids covering the rectangle: shape (N, cutout_h, cutout_w).
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # Shift the rectangle to be centred on the offset and clip it to the image.
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    # Per-sample spatial mask: 1 everywhere except inside the rectangle.
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the mask over the channel dimension.
    x = x * mask.unsqueeze(1)
    return x


# Policy name -> augmentation functions, applied in list order by DiffAugment
# for each name in its comma-separated ``policy`` string.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

In [None]:
#@title accuracy.py
from sklearn.metrics import accuracy_score

def calc_accuracy(preds: torch.Tensor, trues: torch.Tensor) -> float:
    """Top-1 accuracy of ``preds`` against the dominant label of a mixed batch.

    :param preds: logits of shape (batch, num_classes)
    :param trues: target matrix as built by ``GridMixupLoss.get_sample``: row 0
        holds the original labels, row 1 the shuffled labels and the last row the
        mixing weight lambda (same value for every sample)
    :return: fraction of samples whose argmax prediction equals the label of the
        image covering most of the mixed input (row 0 if lambda > 0.5, else row 1)
    """
    lam = float(trues[-1, 0])
    dominant = trues[0] if lam > 0.5 else trues[1]
    # Compare as int64. The previous uint8 casts wrapped labels >= 256 (and
    # negative labels), so e.g. class 299 and class 43 compared as equal.
    true_labels = dominant.long().cpu().numpy()
    # argmax of the logits equals argmax of their softmax; no need to normalise.
    pred_labels = preds.detach().argmax(dim=1).cpu().numpy()
    return float(np.mean(true_labels == pred_labels))

In [None]:
#@title CustomTrainClass.py
from efficientnet_pytorch import EfficientNet
from adamp import AdamP
#from adamp import SGDP
import numpy as np

class CustomTrainClass(pl.LightningModule):
  """LightningModule fine-tuning a pretrained EfficientNet with GridMixup loss.

  Model variant, class count and DiffAugment settings are Colab form fields.
  Per-batch loss/accuracy are collected in ``self.losses``/``self.accuracy`` and,
  at the end of every epoch, averaged, printed, logged and saved with a
  checkpoint of ``self.netD``. Validation/test steps are not implemented.
  """
  def __init__(self):
    super().__init__()
    model_train = 'efficientnet-b0' #@param ["efficientnet-b0", "efficientnet-b1", "efficientnet-b2", "efficientnet-b3", "efficientnet-b4", "efficientnet-b5", "efficientnet-b6", "efficientnet-b7"] {type:"string"}
    num_classes = 2 #@param
    # Fail fast on an unsupported name; the old if/elif chain had no else branch
    # and silently left self.netD undefined.
    valid_models = [f'efficientnet-b{i}' for i in range(8)]
    if model_train not in valid_models:
      raise ValueError(f"model_train must be one of {valid_models}, got '{model_train}'")
    self.netD = EfficientNet.from_pretrained(model_train, num_classes=num_classes)

    #weights_init(self.netD, 'kaiming') #only use this if there is no pretrain

    self.criterion = GridMixupLoss(
        alpha=(0.4, 0.7),
        hole_aspect_ratio=1.,
        crop_area_ratio=(0.5, 1),
        crop_aspect_ratio=(0.5, 2),
        n_holes_x=(2, 6)
    )
    self.accuracy = []
    self.losses = []
    self.diffaug_activate = False #@param

    #@markdown Supports ``'color,translation,cutout'``
    self.policy = 'color' #@param

  def training_step(self, train_batch, batch_idx):
    """GridMixup the batch, optionally DiffAugment it, and return the mixed loss."""
    # get_sample expects labels of shape (batch, 1) and returns a (3, batch)
    # matrix: [original labels; shuffled labels; lambda].
    inputs, targets = self.criterion.get_sample(images=train_batch[0], targets=train_batch[1].unsqueeze(-1))
    # No label shift: ImageFolder labels are already 0..num_classes-1. The former
    # `targets - 1` produced class index -1 and also shifted the lambda row,
    # turning the mixing weight negative.

    if self.diffaug_activate:
      inputs = DiffAugment(inputs, policy=self.policy)
    preds = self.netD(inputs)

    # Calculate loss
    loss = self.criterion(preds, targets)

    self.accuracy.append(calc_accuracy(preds, targets))
    self.losses.append(loss.item())
    return loss

  def configure_optimizers(self):
    #optimizer = torch.optim.Adam(self.netD.parameters(), lr=2e-3)
    optimizer = AdamP(self.netD.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=1e-2)
    #optimizer = SGDP(self.netD.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
    return optimizer

  def training_epoch_end(self, training_step_outputs):
    """Report the epoch's mean loss/accuracy, reset the buffers and save a checkpoint."""
    loss_mean = np.mean(self.losses)
    accuracy_mean = np.mean(self.accuracy)
    print(f"'Epoch': {self.current_epoch}, 'loss': {loss_mean}, 'accuracy': {accuracy_mean}")

    # logging
    self.log('loss_mean', loss_mean, prog_bar=True, logger=True, on_epoch=True)
    self.log('accuracy_mean', accuracy_mean, prog_bar=True, logger=True, on_epoch=True)

    self.losses = []
    self.accuracy = []

    # Save this module's own network. The old code went through the notebook-global
    # `trainer`, which fails if the trainer is named differently or wraps the model.
    # `.3f` (3 decimals) was the intended format; `3f` meant "width 3".
    torch.save(self.netD.state_dict(), f"Checkpoint_{self.current_epoch}_{self.global_step}_loss_{loss_mean:.3f}_acc_{accuracy_mean:.3f}_D.pth")

  def validation_step(self, train_batch, train_idx):
    print("not implemented")

  def test_step(self, train_batch, train_idx):
    print("not implemented")

In [None]:
#@title training
# Train on the images in /content/data/ (one sub-folder per class) at 256 px.
# Trainer arguments such as `gpus` and `progress_bar_refresh_rate` are from the
# PyTorch Lightning 1.x API.
dm = DataModule(batch_size=2, training_path='/content/data/', num_workers = 1, size = 256)
model = CustomTrainClass()
# skipping validation with limit_val_batches=0
trainer = pl.Trainer(limit_val_batches=0, gpus=1, max_epochs=150, progress_bar_refresh_rate=20, default_root_dir='/content/')
trainer.fit(model, dm)

# Test

In [None]:
#@title sort files with predictions (configured for 2 classes)
import torch
import glob
from efficientnet_pytorch import EfficientNet
import cv2
import torch.nn.functional as F
import shutil
import os
from tqdm import tqdm

model_train = 'efficientnet-b0' #@param ["efficientnet-b0", "efficientnet-b1", "efficientnet-b2", "efficientnet-b3", "efficientnet-b4", "efficientnet-b5", "efficientnet-b6", "efficientnet-b7"] {type:"string"}
num_classes = 2 #@param
# Fail fast on an unsupported name (the old if/elif chain had no else branch and
# left `model` undefined or stale).
_valid_models = [f'efficientnet-b{i}' for i in range(8)]
if model_train not in _valid_models:
  raise ValueError(f"model_train must be one of {_valid_models}, got '{model_train}'")
# Build the bare architecture: every weight is replaced by the checkpoint loaded
# below (load_state_dict is strict), so downloading ImageNet weights is wasted work.
model = EfficientNet.from_name(model_train, num_classes=num_classes)

model_path = '/content/test.pth' #@param {type:"string"}
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rootdir = '/content/data' #@param {type:"string"}

# One output folder per predicted class index. The index order follows the class
# order used in training (ImageFolder sorts class folder names alphabetically).
path0 = '/content/output/0' #@param {type:"string"}
path1 = '/content/output/1' #@param {type:"string"}

os.makedirs(path0, exist_ok=True)
os.makedirs(path1, exist_ok=True)

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files.extend(files_jpg)

model.to(device)
# Inference mode: otherwise BatchNorm normalises with the statistics of each
# single-image batch and dropout/drop-connect stay active, so predictions are
# unreliable.
model.eval()

height_min = 256
width_min = 256

# Normalization statistics; must match the values the model was trained with
# (the DataModule's means/std).
norm_means = [0.6013, 0.5443, 0.5521] #@param {type:"raw"}
norm_stds = [0.2496, 0.2509, 0.2433] #@param {type:"raw"}
mean_t = torch.tensor(norm_means, device=device).view(1, 3, 1, 1)
std_t = torch.tensor(norm_stds, device=device).view(1, 3, 1, 1)

with torch.no_grad():
    for f in tqdm(files):
        image = cv2.imread(f)
        if image is None:  # unreadable / corrupt file
            print(f"Skipping unreadable image: {f}")
            continue
        # Training loaded RGB images via PIL; OpenCV returns BGR.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Scale the shorter side to the training size, like transforms.Resize(int).
        # (The old code never resized large portrait images: the resize call sat
        # inside the wrong branch.)
        height, width = image.shape[:2]
        if width < height:
            width_resized, height_resized = width_min, height * width_min / width
        else:
            width_resized, height_resized = width * height_min / height, height_min
        image = cv2.resize(image, (round(width_resized), round(height_resized)))

        # HWC uint8 -> NCHW float in [0, 1], then normalise exactly as in training.
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().div(255)
        image = (image.to(device) - mean_t) / std_t

        y_pred = model(image)
        # argmax of the logits equals argmax of their softmax.
        top_pred = int(y_pred.argmax(dim=1).item())

        if top_pred == 0:
            shutil.move(f, os.path.join(path0, os.path.basename(f)))
        elif top_pred == 1:
            shutil.move(f, os.path.join(path1, os.path.basename(f)))

----------------------------------------------------------------------

# TensorRT

Original Colab: [here](https://colab.research.google.com/drive/1oe_aflRfCwRehho_8QlD8YFKUQ_I5sA6#scrollTo=tgSrJdHic-Du)

Colab-torch2trt: [styler00dollar/Colab-torch2trt](https://github.com/styler00dollar/Colab-torch2trt/blob/main/Colab-torch2trt.ipynb)

onnx-tensorrt: [onnx/onnx-tensorrt](https://github.com/onnx/onnx-tensorrt)


TensorRT gives better performance; quick testing showed roughly a 12x speed-up. You need to download 2 files.
Colab currently ships CUDA 11.0, which is why you need:

```
nv-tensorrt-repo-ubuntu1804-cuda11.0-trt7.2.3.4-ga-20210226_1-1_amd64.deb

and
 
TensorRT-7.2.3.4.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.1.tar\TensorRT-7.2.3.4\python\tensorrt-7.2.3.4-cp37-none-linux_x86_64.whl (inside TensorRT-7.2.3.4.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.1.tar.gz)
```

You can download these files [here](https://developer.nvidia.com/nvidia-tensorrt-download). Warning: You need an account (which can be created for free).

If you want to use other versions, you need to adjust the install script.

In [None]:
!nvcc --version

In [None]:
#@title Mount Google Drive
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

In [None]:
#@title install
import os
# Environment variables read by the shell commands below as ${os1}, ${tag} and
# ${version}: they select the TensorRT repo .deb and pin the apt package versions.
# `tag` and `version` must match the .deb you downloaded.
os.environ["os1"]="ubuntu1804"
os.environ["tag"]= "cuda11.0-trt7.2.3.4-ga-20210226" #@param
os.environ["version"]= "7.2.3-1+cuda11.0" #@param
# Google Drive folder holding the downloaded .deb; dpkg below runs from here.
data_path = '/content/drive/MyDrive/tensorrt 11.0/' #@param
os.chdir(data_path)
!sudo dpkg -i nv-tensorrt-repo-${os1}-${tag}_1-1_amd64.deb
!sudo apt-key add /var/nv-tensorrt-repo-${tag}/7fa2af80.pub
!sudo apt-get update
!sudo apt-get install libnvinfer7=${version} libnvonnxparsers7=${version} libnvparsers7=${version} libnvinfer-plugin7=${version} libnvinfer-dev=${version} libnvonnxparsers-dev=${version} libnvparsers-dev=${version} libnvinfer-plugin-dev=${version} python-libnvinfer=${version} python3-libnvinfer=${version}
!sudo apt-mark hold libnvinfer7 libnvonnxparsers7 libnvparsers7 libnvinfer-plugin7 libnvinfer-dev libnvonnxparsers-dev libnvparsers-dev libnvinfer-plugin-dev python-libnvinfer python3-libnvinfer
!sudo apt-get install tensorrt=${version}
!sudo apt-get install python3-libnvinfer-dev=${version}

**Restart colab (Runtime > Restart Runtime)**

In [None]:
!pip install "/content/drive/MyDrive/tensorrt 11.0/tensorrt-7.2.3.4-cp37-none-linux_x86_64.whl"

In [None]:
#@title Convert to onnx
%cd /content/
# Parameters of the trained checkpoint; they must match how it was trained.
num_classes = 2 #@param
net_name = 'efficientnet-b0' #@param
# Imported here because this cell runs after a runtime restart, when the
# imports from the training cells are gone.
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_name(net_name, num_classes=num_classes)
# The EfficientNet-PyTorch README recommends the plain Swish for ONNX export.
model.set_swish(memory_efficient=False)

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

dummy_input = Variable(torch.randn(1, 3, 256, 256)) # don't set it too high, will run out of RAM
model_path = '/content/test.pth' #@param
# map_location lets a checkpoint saved on GPU load regardless of the current device.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

model.load_state_dict(state_dict)
# Export the inference graph (BatchNorm running stats, dropout disabled).
model.eval()
torch.onnx.export(model, dummy_input, "output.onnx")

In [None]:
#@title Install
!pip install pycuda
!pip install onnx
%cd /content/
!git clone https://github.com/onnx/onnx-tensorrt
%cd /content/onnx-tensorrt/
!python setup.py install

In [None]:
#@title Example usage of tensorrt backend
import onnx
import onnx_tensorrt.backend as backend
import numpy as np
onnx_path = "/content/output.onnx" #@param
model = onnx.load(onnx_path)
engine = backend.prepare(model, device='CUDA:0', fp16_mode=True)
input_data = np.random.random(size=(1, 3, 256, 256)).astype(np.float32)
output_data = engine.run(input_data)[0]
print(output_data)
print(output_data.shape)

In [None]:
#@title sort files with predictions (configured for 2 classes)
import torch
import glob
from efficientnet_pytorch import EfficientNet
import cv2
import torch.nn.functional as F
import shutil
import os
from tqdm import tqdm
import onnx
import onnx_tensorrt.backend as backend
import numpy as np
onnx_path = "/content/output.onnx" #@param
# Build an FP16 TensorRT engine from the exported ONNX graph.
model = onnx.load(onnx_path)
engine = backend.prepare(model, device='CUDA:0', fp16_mode=True)

rootdir = '/content/data' #@param {type:"string"}

# One output folder per predicted class index; the index order follows the class
# order used in training (ImageFolder sorts class folder names).
path0 = '/content/output/0' #@param {type:"string"}
path1 = '/content/output/1' #@param {type:"string"}

for _out_dir in (path0, path1):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)

# Collect PNGs first, then JPGs, searching rootdir recursively.
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files = glob.glob(rootdir + '/**/*.png', recursive=True) + files_jpg

# Target size of the shorter image side and of the centre crop.
height_min, width_min = 256, 256


def crop_center(img,cropx,cropy):
    """Return the central ``cropy`` x ``cropx`` window of an (H, W, C) image."""
    height, width, _channels = img.shape
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[rows, cols]

def softmax(x):
    """Row-wise softmax over axis 1 of a (batch, num_classes) score array.

    The per-row maximum is subtracted before exponentiating for numerical
    stability. ``keepdims`` makes the normalisation broadcast correctly for any
    batch size; the previous ``e_x.sum(axis=1)`` only worked for a batch of one.
    """
    x = np.asarray(x)
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)

# Normalization statistics; must match the values the model was trained with
# (the DataModule's means/std).
norm_means = [0.6013, 0.5443, 0.5521] #@param {type:"raw"}
norm_stds = [0.2496, 0.2509, 0.2433] #@param {type:"raw"}
mean_hwc = np.asarray(norm_means, dtype=np.float32).reshape(1, 1, 3)
std_hwc = np.asarray(norm_stds, dtype=np.float32).reshape(1, 1, 3)

for f in tqdm(files):
    image = cv2.imread(f)
    if image is None:  # unreadable / corrupt file
        print(f"Skipping unreadable image: {f}")
        continue
    # The model was trained on RGB images (PIL via ImageFolder); OpenCV loads BGR.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Scale the shorter side to the training size, like transforms.Resize(int).
    # (The old code never resized large portrait images.)
    height, width = image.shape[:2]
    if width < height:
        width_resized, height_resized = width_min, height * width_min / width
    else:
        width_resized, height_resized = width * height_min / height, height_min
    image = cv2.resize(image, (round(width_resized), round(height_resized)))

    # The exported ONNX graph has a fixed 256x256 input, so take the centre crop.
    image = crop_center(image, width_min, height_min)

    # Same preprocessing as training: ToTensor (scale to [0, 1]) then Normalize.
    image = (image.astype(np.float32) / 255.0 - mean_hwc) / std_hwc
    # HWC -> NCHW. np.swapaxes(image, 0, 2) (used before) yields C x W x H,
    # i.e. a transposed image.
    image = np.ascontiguousarray(image.transpose(2, 0, 1)[np.newaxis])

    y_pred = engine.run(image)[0]
    top_pred = int(np.argmax(softmax(y_pred), axis=1)[0])

    if top_pred == 0:
        shutil.move(f, os.path.join(path0, os.path.basename(f)))
    elif top_pred == 1:
        shutil.move(f, os.path.join(path1, os.path.basename(f)))


----------------------------------

# Model Compression Through Teacher-Student Knowledge Distillation

After training a normal model, you can train a smaller student model with the bigger model acting as its teacher. The student should reach performance close to the teacher's while being much smaller. The teacher contributes to the loss inside the training loop. [Here is a helpful link with more information](https://towardsdatascience.com/model-distillation-and-compression-for-recommender-systems-in-pytorch-5d81c0f2c0ec).

The example below modifies EfficientNet so that the model is 709 KB, while the original is 20.48 MB. The code below assumes a b0 teacher. The other cells are identical.

Warning: Restart the runtime afterwards so that the modified EfficientNet module is re-imported!

In [None]:
#@title model.py (adding efficientnet-s0)
%%writefile /usr/local/lib/python3.7/dist-packages/efficientnet_pytorch/model.py
"""model.py - Model and module class for EfficientNet.
   They are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
    round_filters,
    round_repeats,
    drop_connect,
    get_same_padding_conv2d,
    get_model_params,
    efficientnet_params,
    load_pretrained_weights,
    Swish,
    MemoryEfficientSwish,
    calculate_output_image_size
)


# Model names accepted by EfficientNet._check_model_name_is_valid: the standard
# b0-b8 family, the custom 's0' variant added by this notebook, and
# 'efficientnet-l2' (supported for construction without pretrained weights).
VALID_MODELS = tuple('efficientnet-b{}'.format(i) for i in range(9)) + (
    'efficientnet-s0',
    'efficientnet-l2',
)


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Pipeline: optional 1x1 expansion conv (only when expand_ratio != 1)
    -> depthwise conv -> optional squeeze-and-excitation (only when
    0 < se_ratio <= 1) -> 1x1 projection conv. An identity skip connection,
    combined with drop connect, is added when id_skip is set, the stride is 1
    and input_filters == output_filters.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParams, defined in utils.py.
        image_size (tuple or list): [image_height, image_width] of this
            block's input. When given, convolutions use static 'same'
            padding precomputed for that size; when None, padding is
            computed dynamically in forward (see get_same_padding_conv2d).

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    # NOTE: the submodule attribute names assigned in __init__ become
    # state_dict keys; renaming them breaks loading of existing checkpoints.
    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of expanded channels (used through the depthwise/SE stages)
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        # The depthwise conv is the only strided layer in the block, so the
        # spatial size is updated here for the projection conv below.
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # SE operates on a 1x1 pooled map, hence the fixed (1, 1) size.
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            # The bottleneck width is derived from input_filters, not from the
            # expanded channel count oup.
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate (between 0
                and 1). A falsy value disables drop connect; it is only used
                when the skip connection is active.

        Returns:
            Output of this block after processing.
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            # Global average pool to 1x1, bottleneck, then gate each channel
            # of x by a sigmoid weight.
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution (no activation after the projection BN)
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class EfficientNet(nn.Module):
    """EfficientNet model.
       Most easily loaded with the .from_name or .from_pretrained methods.

    Layout: stem (3x3 stride-2 conv + BN + swish) -> stack of MBConvBlocks
    -> head (1x1 conv + BN + swish) -> global average pooling -> (only when
    include_top) flatten, dropout and a fully connected classifier.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> import torch
        >>> from efficientnet_pytorch.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    # NOTE: submodule attribute names (_conv_stem, _bn0, _blocks, _conv_head,
    # _bn1, _fc) become state_dict keys; renaming them breaks checkpoint loading.
    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        # Track the spatial size so each block can precompute static padding.
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1: # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Applies to the model's own swish and to every MBConvBlock.

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.

        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Use convolution layer to extract features
        from reduction levels i in [1, 2, 3, 4, 5].

        An endpoint is recorded just before each block that shrinks the
        spatial size (the feature map before the reduction), plus the head
        output as the last endpoint.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features
            with reduction levels i in [1, 2, 3, 4, 5].
            Example:
                >>> import torch
                >>> from efficientnet_pytorch.model import EfficientNet
                >>> inputs = torch.rand(1, 3, 224, 224)
                >>> model = EfficientNet.from_pretrained('efficientnet-b0')
                >>> endpoints = model.extract_endpoints(inputs)
                >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
                >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
                >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
                >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
                >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            # Compares dim 2 (height) only; a drop means this block reduced resolution.
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints)+1)] = x

        return endpoints

    def extract_features(self, inputs):
        """use convolution layer to extract feature .

        Runs the stem, all MBConvBlocks (with drop connect rate growing
        linearly with block index) and the head conv.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """EfficientNet's forward function.
           Calls extract_features to extract features, applies final linear layer, and returns logits.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of this model after processing: logits when include_top
            is truthy, otherwise the pooled features (still 4-D, spatially 1x1).
        """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        if self._global_params.include_top:
            x = x.flatten(start_dim=1)
            x = self._dropout(x)
            x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet (must be in VALID_MODELS).
            in_channels (int): Input data's channel number. When != 3 the stem
                conv is replaced by a freshly initialized one.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.

        Raises:
            ValueError: If model_name is not in VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        # load_fc is only True for 1000 classes; presumably the pretrained
        # classifier head is skipped otherwise (see load_pretrained_weights).
        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
        # Done after loading weights, so a non-RGB stem stays freshly initialized.
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If model_name is not in VALID_MODELS (nothing is
            returned otherwise).
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The replacement stem conv is newly created, so any pretrained stem
        weights are discarded.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)


In [None]:
#@title custom sized EfficientNet (edit params_dict) (efficientnet-s0)
%%writefile /usr/local/lib/python3.7/dist-packages/efficientnet_pytorch/utils.py
"""utils.py - Helper functions for building the model and for loading model parameters.
   These helper functions are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo


################################################################################
### Help functions for model architecture
################################################################################

# GlobalParams and BlockArgs: Two namedtuples
# Swish and MemoryEfficientSwish: Two implementations of the method
# round_filters and round_repeats:
#     Functions to calculate params for scaling model width and depth ! ! !
# get_width_and_height_from_size and calculate_output_image_size
# drop_connect: A structural design
# get_same_padding_conv2d:
#     Conv2dDynamicSamePadding
#     Conv2dStaticSamePadding
# get_same_padding_maxPool2d:
#     MaxPool2dDynamicSamePadding
#     MaxPool2dStaticSamePadding
#     It's an additional function, not used in EfficientNet,
#     but can be used in other model (such as EfficientDet).

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
    'GlobalParams',
    'width_coefficient depth_coefficient image_size dropout_rate '
    'num_classes batch_norm_momentum batch_norm_epsilon '
    'drop_connect_rate depth_divisor min_depth include_top')

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    'BlockArgs',
    'num_repeat kernel_size stride expand_ratio '
    'input_filters output_filters se_ratio id_skip')

# Every field defaults to None, so callers may supply only the fields they need.
GlobalParams.__new__.__defaults__ = tuple([None] * len(GlobalParams._fields))
BlockArgs.__new__.__defaults__ = tuple([None] * len(BlockArgs._fields))


# An ordinary implementation of Swish function
class Swish(nn.Module):
    """Standard Swish activation, ``x * sigmoid(x)``.

    Relies on ordinary autograd; set_swish in model.py selects this version
    for export and MemoryEfficientSwish for training.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    """Swish with a hand-written backward pass.

    Only the raw input is saved for backward; the sigmoid is recomputed
    there rather than being stored from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        sig = torch.sigmoid(saved_input)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        local_grad = sig * (1 + saved_input * (1 - sig))
        return grad_output * local_grad

class MemoryEfficientSwish(nn.Module):
    """Swish module backed by the custom autograd function SwishImplementation."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        return swish_fn(x)


def round_filters(filters, global_params):
    """Scale a channel count by the width multiplier and round it.

    The scaled value is rounded to the nearest multiple of
    ``global_params.depth_divisor`` and clamped to at least ``min_depth``
    (which falls back to the divisor when unset). If that rounding lost more
    than 10% of the scaled value, one extra divisor is added. This follows
    the official TensorFlow implementation.

    Args:
        filters (int): Unscaled number of filters.
        global_params (namedtuple): Reads width_coefficient, depth_divisor
            and min_depth.

    Returns:
        int: The scaled, rounded filter count, or ``filters`` unchanged when
        width_coefficient is falsy (None or 0).
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters
    # NOTE: despite their names, depth_divisor and min_depth control the
    # rounding of the *width* (channel count) here.
    divisor = global_params.depth_divisor
    floor_value = global_params.min_depth or divisor
    scaled = filters * width_mult
    rounded = max(floor_value, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:  # never round down by more than 10%
        rounded += divisor
    return int(rounded)


def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier.

    The product is always rounded up, as in the official TensorFlow
    implementation.

    Args:
        repeats (int): Unscaled num_repeat of a block.
        global_params (namedtuple): Reads depth_coefficient.

    Returns:
        int: ``ceil(depth_coefficient * repeats)``, or ``repeats`` unchanged
        when depth_coefficient is falsy (None or 0).
    """
    depth_mult = global_params.depth_coefficient
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats


def drop_connect(inputs, p, training):
    """Drop connect (per-sample stochastic depth).

    During training each sample in the batch is zeroed with probability
    ``p`` and the surviving samples are scaled by ``1 / (1 - p)``, so the
    expected value is unchanged. Outside training ``inputs`` is returned
    as-is.

    Args:
        inputs (tensor): Batched input; dim 0 is the batch dimension. The
            mask has shape [B, 1, 1, 1], so a 4-D input is assumed.
        p (float: 0.0~1.0): Probability of dropping a sample.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection. With ``p == 1`` every sample
        is dropped and an all-zero tensor is returned.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # Every sample is dropped. The rescaling below would divide by zero
        # and turn the result into NaN instead of zeros.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size. An int denotes a square size.

    Returns:
        size: ``(x, x)`` for an int; a list or tuple is returned unchanged
        (callers unpack it as (H, W)).

    Raises:
        TypeError: If ``x`` is not an int, list or tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(
        'size must be an int, list or tuple, got {}'.format(type(x).__name__))


def calculate_output_image_size(input_image_size, stride):
    """Spatial size produced by a 'same'-padded conv with the given stride.

    Each dimension becomes ``ceil(size / stride)``. The result is needed to
    precompute static padding for the next layer. Thanks to mannatsingh for
    pointing this out.

    Args:
        input_image_size (int, tuple, list or None): Size of the input image.
        stride (int, tuple or list): Conv2d operation's stride; only the
            first component is used, so square strides are assumed.

    Returns:
        A list [H, W], or None when input_image_size is None.
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(height / step)), int(math.ceil(width / step))]


# Note:
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names ! ! !

def get_same_padding_conv2d(image_size=None):
    """Pick a Conv2d class with TensorFlow-style 'SAME' padding.

    Static padding (precomputed for a known image size) is necessary for ONNX
    export; dynamic padding recomputes the padding from every input.

    Args:
        image_size (int or tuple): Size of the image, or None for dynamic padding.

    Returns:
        Conv2dDynamicSamePadding when image_size is None, otherwise
        Conv2dStaticSamePadding with image_size bound via functools.partial.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for any input size.

    Padding is recomputed from the input's spatial size on every forward
    call, so the output size is ceil(input / stride). When the total padding
    is odd, the extra pixel goes on the bottom/right.
    """

    # Padding derivation, per spatial dimension:
    #     i: input size, s: stride, k: kernel size, d: dilation, p: total padding
    #     Conv2d output size: o = floor((i + p - ((k - 1) * d + 1)) / s + 1)
    #     Requiring o = ceil(i / s) gives
    #     p = max((o - 1) * s + (k - 1) * d + 1 - i, 0)

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,
                         dilation=dilation, groups=groups, bias=bias)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        out_h = math.ceil(in_h / s_h)  # output size depends on the stride
        out_w = math.ceil(in_w / s_w)
        total_h = max((out_h - 1) * s_h + (k_h - 1) * self.dilation[0] + 1 - in_h, 0)
        total_w = max((out_w - 1) * s_w + (k_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if total_h or total_w:
            top, left = total_h // 2, total_w // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
       The padding module is computed in the constructor and applied in forward.
       When the total padding is odd, the extra pixel goes on the bottom/right,
       matching Conv2dDynamicSamePadding and TensorFlow.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d takes (left, right, top, bottom). The smaller half goes
            # first so an odd total adds the extra pixel on the right/bottom;
            # padding both sides with the larger half would over-pad and shift
            # the sampling grid.
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Pick a MaxPool2d class with TensorFlow-style 'SAME' padding.

    Static padding (precomputed for a known image size) is necessary for ONNX
    export; dynamic padding recomputes the padding from every input.

    Args:
        image_size (int or tuple): Size of the image, or None for dynamic padding.

    Returns:
        MaxPool2dDynamicSamePadding when image_size is None, otherwise
        MaxPool2dStaticSamePadding with image_size bound via functools.partial.
    """
    if image_size is not None:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
    return MaxPool2dDynamicSamePadding


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D max pooling like TensorFlow's 'SAME' mode, for any input size.

    Padding is recomputed from the input's spatial size on every forward call
    so the output size is ceil(input / stride); an odd total puts the extra
    pixel on the bottom/right.

    NOTE(review): padding uses zeros (F.pad default), so border windows whose
    real values are all negative return 0. TensorFlow's SAME max-pooling
    ignores padded positions instead. Confirm this is acceptable for callers.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        # Normalize scalar hyper-parameters to [h, w] pairs for forward().
        for attr in ('stride', 'kernel_size', 'dilation'):
            value = getattr(self, attr)
            if isinstance(value, int):
                setattr(self, attr, [value] * 2)

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.kernel_size
        s_h, s_w = self.stride
        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
        total_h = max((out_h - 1) * s_h + (k_h - 1) * self.dilation[0] + 1 - in_h, 0)
        total_w = max((out_w - 1) * s_w + (k_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if total_h or total_w:
            top, left = total_h // 2, total_w // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)

class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D max pooling like TensorFlow's 'SAME' mode, for a fixed input size.

    The padding is computed once in the constructor from ``image_size`` and
    applied in forward by a ZeroPad2d (or an Identity when no padding is
    needed). An odd total puts the extra pixel on the bottom/right.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # Normalize scalar hyper-parameters to [h, w] pairs.
        for attr in ('stride', 'kernel_size', 'dilation'):
            value = getattr(self, attr)
            if isinstance(value, int):
                setattr(self, attr, [value] * 2)

        # Padding depends only on image_size, so precompute it here.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        k_h, k_w = self.kernel_size
        s_h, s_w = self.stride
        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
        total_h = max((out_h - 1) * s_h + (k_h - 1) * self.dilation[0] + 1 - in_h, 0)
        total_w = max((out_w - 1) * s_w + (k_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if total_h or total_w:
            top, left = total_h // 2, total_w // 2
            self.static_padding = nn.ZeroPad2d((left, total_w - left, top, total_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        padded = self.static_padding(x)
        return F.max_pool2d(padded, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)


################################################################################
### Helper functions for loading model params
################################################################################

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights

class BlockDecoder(object):
    """Convert EfficientNet block arguments to and from their string notation.

    Follows the block-string format of the official TensorFlow repository,
    e.g. ``'r1_k3_s11_e1_i32_o16_se0.25_noskip'``.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Parse one block string into a BlockArgs namedtuple.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: ``stride`` is a one-element list holding the first stride
            digit; ``se_ratio`` is None when no ``se`` option is present.

        Raises:
            AssertionError: if ``block_string`` is not a str, or its stride option
                is missing or has two different digits.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split 'se0.25' into ('se', '0.25'); ops without digits (e.g. 'noskip') are skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Stride must be given as one digit or as two equal digits (square stride).
        # A missing 's' option now fails this check instead of raising KeyError.
        stride_str = options.get('s', '')
        assert (len(stride_str) == 1 or
                (len(stride_str) == 2 and stride_str[0] == stride_str[1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs-like object with fields num_repeat,
                kernel_size, stride, expand_ratio, input_filters, output_filters,
                se_ratio and id_skip. ``stride`` may be an int, a one-element
                list (as produced by decoding) or a two-element sequence.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # BlockArgs stores the field as 'stride' (not 'strides'). Decoding keeps only one digit,
        # and _replace may set it to a plain int, so normalise to a (height, width) pair.
        stride = block.stride
        if isinstance(stride, int):
            stride_h = stride_w = stride
        else:
            stride_h = stride[0]
            stride_w = stride[1] if len(stride) > 1 else stride[0]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # Decoding yields se_ratio=None when there is no 'se' option; skip it instead of
        # comparing None with numbers.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string) for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]


def efficientnet_params(model_name):
    """Look up the compound-scaling coefficients for an EfficientNet variant.

    Args:
        model_name (str): One of 'efficientnet-b0' ... 'efficientnet-b8',
            'efficientnet-l2' or the custom small 'efficientnet-s0'.

    Returns:
        tuple: (width_coefficient, depth_coefficient, resolution, dropout_rate).

    Raises:
        KeyError: if ``model_name`` is not in the table.
    """
    # name,              width, depth, resolution, dropout
    rows = (
        ('efficientnet-b0', 1.0, 1.0, 224, 0.2),
        ('efficientnet-b1', 1.0, 1.1, 240, 0.2),
        ('efficientnet-b2', 1.1, 1.2, 260, 0.3),
        ('efficientnet-b3', 1.2, 1.4, 300, 0.3),
        ('efficientnet-b4', 1.4, 1.8, 380, 0.4),
        ('efficientnet-b5', 1.6, 2.2, 456, 0.4),
        ('efficientnet-b6', 1.8, 2.6, 528, 0.5),
        ('efficientnet-b7', 2.0, 3.1, 600, 0.5),
        ('efficientnet-b8', 2.2, 3.6, 672, 0.5),
        ('efficientnet-l2', 4.3, 5.3, 800, 0.5),
        # Custom small variant; neither pretrained-URL table has an entry for it.
        ('efficientnet-s0', 0.1, 0.1, 224, 0.2),
    )
    table = {name: (width, depth, res, dropout) for name, width, depth, res, dropout in rows}
    return table[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Build the baseline block list and the GlobalParams for an EfficientNet.

    Args:
        width_coefficient (float): channel-width multiplier, stored in GlobalParams.
        depth_coefficient (float): layer-repeat multiplier, stored in GlobalParams.
        image_size (int): input resolution, stored in GlobalParams.
        dropout_rate (float): dropout rate, stored in GlobalParams.
        drop_connect_rate (float): drop-connect rate, stored in GlobalParams.
        num_classes (int): number of output classes, stored in GlobalParams.
        include_top (bool): stored in GlobalParams as ``include_top``.

    Returns:
        (blocks_args, global_params): the decoded BlockArgs list for the seven
        baseline (b0) stages, and a GlobalParams built from the arguments plus
        the fixed batch-norm / depth-divisor settings below.
    """
    # Baseline (efficientnet-b0) stage definitions. Per the original note, they are
    # modified later when the EfficientNet class is constructed for a given model.
    stage_specs = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_stages = BlockDecoder.decode(stage_specs)

    params = GlobalParams(
        # Values supplied by the caller.
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        num_classes=num_classes,
        include_top=include_top,
        # Constants shared by every variant.
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        depth_divisor=8,
        min_depth=None,
    )
    return decoded_stages, params


def get_model_params(model_name, override_params):
    """Resolve the block arguments and global parameters for a named model.

    Args:
        model_name (str): Model name; only names starting with 'efficientnet'
            are supported (unknown efficientnet names raise KeyError from
            efficientnet_params).
        override_params (dict or None): GlobalParams field values to substitute;
            ignored when empty or None.

    Returns:
        (blocks_args, global_params) from efficientnet(), with overrides applied.

    Raises:
        NotImplementedError: if ``model_name`` does not start with 'efficientnet'.
        ValueError: from ``_replace`` if override_params names a field that
            global_params does not have.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # drop_connect_rate is not passed, so every variant uses efficientnet()'s default of 0.2.
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if not override_params:
        return blocks_args, global_params
    return blocks_args, global_params._replace(**override_params)


# Checkpoints trained with the standard recipe, keyed by model name.
# load_pretrained_weights downloads from this table when weights_path is not a str and advprop is False.
# It has no entry for 'efficientnet-b8', 'efficientnet-l2' or 'efficientnet-s0';
# looking those up here raises KeyError.
# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}

# Checkpoints trained with adversarial examples (AdvProp), keyed by model name.
# load_pretrained_weights uses this table instead of url_map when advprop=True.
# Unlike url_map, it includes 'efficientnet-b8'; 'efficientnet-l2' and 'efficientnet-s0' are absent.
# check more details in paper(Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Load pretrained weights into ``model`` from a local file or a release URL.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet; used as the key into
            url_map / url_map_advprop when weights_path is not a str.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: download the checkpoint from url_map (or url_map_advprop).
        load_fc (bool): Whether to load the '_fc.weight' / '_fc.bias' entries
            (the final fc layer). When False they are discarded from the
            checkpoint and the model keeps its current fc parameters.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).

    Raises:
        KeyError: if weights are downloaded and model_name has no URL in the selected map.
        AssertionError: if the checkpoint's keys do not match the model as expected.
    """
    if isinstance(weights_path, str):
        # Map to CPU so a checkpoint saved from a GPU also loads on a CPU-only runtime;
        # load_state_dict copies the values into the model's existing parameters.
        state_dict = torch.load(weights_path, map_location='cpu')
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Tolerate checkpoints that were saved without a classifier head.
        state_dict.pop('_fc.weight', None)
        state_dict.pop('_fc.bias', None)
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(
            ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))


In [None]:
#@title CustomTrainClass.py
from efficientnet_pytorch import EfficientNet
from adamp import AdamP
#from adamp import SGDP
import numpy as np

class CustomTrainClass(pl.LightningModule):
    """LightningModule that trains ``netD`` with GridMixup loss plus a teacher-loss term.

    Training only: validation_step / test_step are placeholders. At the end of
    every epoch, the mean training loss and accuracy are printed and logged.
    netD's state_dict is then saved to a ``Checkpoint_*.pth`` file in the
    current working directory.
    """

    def __init__(self):
        super().__init__()
        num_classes = 2 #@param

        # NOTE(review): `EfficientNet` is the class imported from the pip package
        # `efficientnet_pytorch` in this cell. 'efficientnet-s0' is only listed in this
        # notebook's own efficientnet_params table — confirm the imported class accepts it.
        # NOTE(review): netD uses from_name's default class count while the teacher uses
        # num_classes — confirm this mismatch is intended.
        self.netD = EfficientNet.from_name('efficientnet-s0')
        teacher_path = '/content/test_b0_official.pth' #@param
        self.teacher = EfficientNet.from_pretrained('efficientnet-b0', num_classes=num_classes)
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU runtime as well;
        # load_state_dict copies the values into the teacher's existing parameters.
        self.teacher.load_state_dict(torch.load(teacher_path, map_location='cpu'))

        #weights_init(self.netD, 'kaiming') #only use this if there is no pretrain

        self.criterion = GridMixupLoss(
            alpha=(0.4, 0.7),
            hole_aspect_ratio=1.,
            crop_area_ratio=(0.5, 1),
            crop_aspect_ratio=(0.5, 2),
            n_holes_x=(2, 6)
        )
        # Per-step metrics, averaged and cleared in training_epoch_end.
        self.accuracy = []
        self.losses = []
        self.diffaug_activate = False #@param

        #@markdown Supports ``'color,translation,cutout'``
        self.policy = 'color' #@param

    def _maybe_augment(self, images):
        """Return DiffAugment(images, policy=self.policy) when diffaug_activate is set, else images."""
        if not self.diffaug_activate:
            return images
        return DiffAugment(images, policy=self.policy)

    def training_step(self, train_batch, batch_idx):
        inputs, targets = self.criterion.get_sample(images=train_batch[0], targets=train_batch[1].unsqueeze(-1))
        targets = targets - 1  # fixing range

        # Student loss.
        preds = self.netD(self._maybe_augment(inputs))
        loss = self.criterion(preds, targets)

        # Teacher term. preds_teacher does not depend on netD, and the optimizer only
        # updates netD, so this term never changes the student's gradients; it only
        # shifts the reported loss. no_grad skips building/backpropagating the teacher
        # graph while leaving the loss value unchanged.
        # NOTE(review): this is not distillation (student and teacher outputs are never
        # compared) — confirm the intended objective.
        with torch.no_grad():
            preds_teacher = self.teacher(self._maybe_augment(inputs))
        loss += self.criterion(preds_teacher, targets)

        self.accuracy.append(calc_accuracy(preds, targets))
        self.losses.append(loss.item())
        return loss

    def configure_optimizers(self):
        # Only netD's parameters are optimized; the teacher is never updated.
        #optimizer = torch.optim.Adam(self.netD.parameters(), lr=2e-3)
        optimizer = AdamP(self.netD.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=1e-2)
        #optimizer = SGDP(self.netD.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
        return optimizer

    def training_epoch_end(self, training_step_outputs):
        loss_mean = np.mean(self.losses)
        accuracy_mean = np.mean(self.accuracy)
        print(f"'Epoch': {self.current_epoch}, 'loss': {loss_mean}, 'accuracy': {accuracy_mean}")

        # logging
        self.log('loss_mean', loss_mean, prog_bar=True, logger=True, on_epoch=True)
        self.log('accuracy_mean', accuracy_mean, prog_bar=True, logger=True, on_epoch=True)

        self.losses = []
        self.accuracy = []

        # Save through self.netD rather than the notebook-global `trainer`, so saving does
        # not depend on a variable named `trainer` existing. Metrics use 3 decimals (.3f).
        torch.save(self.netD.state_dict(), f"Checkpoint_{self.current_epoch}_{self.global_step}_loss_{loss_mean:.3f}_acc_{accuracy_mean:.3f}_D.pth")

    def validation_step(self, train_batch, train_idx):
        print("not implemented")

    def test_step(self, train_batch, train_idx):
        print("not implemented")