<a href="https://colab.research.google.com/github/vincentbriat/Super-resolution-investigation/blob/main/Knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from os import listdir, environ, path
import sys
from pathlib import Path
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import cv2 as cv 
from sklearn.decomposition import PCA
import numpy as np
from google.colab import drive
import random as rd
import csv
from torchsummary import summary
import matplotlib.pyplot as plt

!pip install pytorch-msssim
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

drive.mount('/content/drive/')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-msssim
  Downloading pytorch_msssim-0.2.1-py3-none-any.whl (7.2 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-0.2.1
Mounted at /content/drive/


In [2]:
# Pick the compute device for this session: CUDA when a GPU is visible, otherwise the CPU.
# NOTE(review): no model or tensor in the visible cells is moved to `device` yet.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

In [16]:
class ConvBlock(nn.Module):
  """A Conv2d optionally followed by an in-place LeakyReLU with negative slope 0.2.

  Extra keyword arguments (stride, padding, ...) are forwarded unchanged to ``nn.Conv2d``.
  When ``use_act`` is False the activation slot holds ``nn.Identity``, so the block is a bare conv.
  """
  def __init__(self, in_channels: int, out_channels: int, kernel_size: int, use_act: bool, **kwargs):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
    if use_act:
      self.activation = nn.LeakyReLU(.2, inplace=True)
    else:
      self.activation = nn.Identity()

  def forward(self, x):
    features = self.conv(x)
    return self.activation(features)

class RDB(nn.Module):
  """Residual dense block: five densely connected 3x3 ConvBlocks plus a scaled residual.

  Conv ``i`` (0..4) receives the block input concatenated with the outputs of all previous
  convs, so its input width is ``in_channels + i * middle_channels``. The first four convs
  emit ``middle_channels`` maps followed by LeakyReLU; the fifth maps back to ``in_channels``
  with no activation. The output is ``residual_scale * last_conv_output + x``.
  """
  def __init__(self, in_channels, middle_channels = 32, residual_scale = .2):
    super().__init__()
    self.residual_scale = residual_scale
    self.block = nn.ModuleList([ConvBlock(in_channels + i * middle_channels,
                                  middle_channels if i<4 else in_channels,
                                  3,
                                  stride=1,
                                  padding=1,
                                  use_act=i<4) for i in range(5)])

  def forward(self, x):
    *dense_convs, last_conv = self.block
    features = x  # named `features` so the builtin `input` is not shadowed
    for conv in dense_convs:
      features = torch.cat([features, conv(features)], dim=1)
    # The last conv's output is only used for the residual; concatenating it too (as the
    # previous loop did) allocated a large unused tensor on every forward pass.
    return self.residual_scale * last_conv(features) + x

class RRDB(nn.Module):
  """Residual-in-residual dense block: three chained RDBs wrapped in a scaled skip connection.

  The same ``residual_scale`` is used for the inner RDB residuals and for this outer skip.
  """
  def __init__(self, in_channels, residual_scale = .2):
    super().__init__()
    self.residual_scale = residual_scale
    dense_blocks = [RDB(in_channels, residual_scale=residual_scale) for _ in range(3)]
    self.model = nn.Sequential(*dense_blocks)

  def forward(self, x):
    residual = self.model(x)
    return x + self.residual_scale * residual

class Head(nn.Module):
  """Shallow feature extractor: a single 3x3 conv lifting 3 input channels to 64 feature maps.

  Padding 1 with stride 1 keeps the spatial size unchanged.
  """
  def __init__(self) -> None:
    super().__init__()
    self.model = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)

  def forward(self, x):
    features = self.model(x)
    return features

class Tail(nn.Module):
  """Reconstruction tail: 64 -> 256 maps, x3 nearest upsampling, LeakyReLU, then 256 -> 3 maps.

  Accepts a batched ``(N, 64, H, W)`` input, returning ``(N, 3, 3H, 3W)``, or an unbatched
  ``(64, H, W)`` input, returning ``(3, 3H, 3W)``.
  """
  def __init__(self) -> None:
    super().__init__()
    self.model = nn.Sequential(nn.Conv2d(64, 256, 3, stride=1, padding=1),
                               nn.Upsample(scale_factor=3, mode='nearest'),
                               nn.LeakyReLU(.2, inplace=True),
                               nn.Conv2d(256, 3, 3, stride=1, padding=1))

  def forward(self, x):
    # nn.Upsample infers how many spatial dims to scale from x.dim(). An unbatched 3-D
    # tensor would be read as (N, C, L), so only its last axis would be upsampled. Add a
    # temporary batch axis so both height and width are scaled.
    if x.dim() == 3:
      return self.model(x.unsqueeze(0)).squeeze(0)
    return self.model(x)

In [31]:
class Student(nn.Module):
  """Super-resolution student that is trained to imitate a teacher's outputs.

  Pipeline: ``head`` conv -> the first ``nb_of_bodies`` of 16 RRDB bodies -> ``post_body_conv``,
  added to the head features (global skip) -> ``tail`` (x3 upsampling back to 3 channels).
  All 16 bodies are built up front and registered with the optimizer, so the depth in use can
  be grown by raising ``nb_of_bodies``.

  The model owns its L1 loss, SGD optimizer and StepLR scheduler. Calling ``train`` with two
  tensors performs one optimisation step.
  """
  def __init__(self):
    super().__init__()
    self.head = Head()
    self.bodies = nn.ModuleList([RRDB(64) for _ in range(16)])
    self.post_body_conv = nn.Conv2d(64, 64, 3, stride=1, padding=1)
    self.tail = Tail()
    self.loss_fn = nn.L1Loss()
    self.opt = torch.optim.SGD(self.parameters(), lr=0.1)
    self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10, gamma=0.5)
    # Number of RRDB bodies actually used by forward() (progressive-training knob).
    self.nb_of_bodies = 1
    # Epoch counter written by save() and initialize_csv(); initialised here so both work
    # on a freshly built model, not only after load().
    self.cur_epoch = 0

  def forward(self, x: torch.Tensor):
    """Super-resolve ``x``; the head expects 3 input channels (presumably an RGB batch)."""
    head_features = self.head(x)
    output = head_features
    # Index access (not slicing) so nb_of_bodies > 16 still fails loudly with IndexError.
    for i in range(self.nb_of_bodies):
      output = self.bodies[i](output)
    # Global residual: reuse the head features rather than running the head a second time.
    return self.tail(head_features + self.post_body_conv(output))

  def train(self, output_generator=True, output_teacher=None):
    """Run one imitation step, or switch train/eval mode like ``nn.Module.train``.

    ``nn.Module.eval()`` and parent modules call ``train(mode)`` with a single bool. A bool
    first argument is therefore forwarded to ``nn.Module.train`` and ``self`` is returned, so
    ``student.eval()`` keeps working.

    Args:
      output_generator: generator images fed to the student (or a bool mode flag).
      output_teacher: teacher outputs used as the L1 target for the step.

    Returns:
      The loss tensor of the step, or ``self`` when used as a mode switch.

    Raises:
      TypeError: a tensor step is requested without ``output_teacher``.
    """
    if isinstance(output_generator, bool):
      return super().train(output_generator)
    if output_teacher is None:
      raise TypeError("Student.train() missing required argument: 'output_teacher'")
    super().train()
    output_student = self(output_generator)
    return self.learn(output_student, output_teacher)

  def learn(self, output_student: torch.Tensor, output_teacher: torch.Tensor):
    """Take one optimizer step on the L1 loss between student and teacher outputs; return the loss."""
    loss = self.loss_fn(output_student, output_teacher)
    self.opt.zero_grad()
    loss.backward()
    self.opt.step()
    return loss

  def initialize_csv(self, path, loss):
    """Append an ``epoch,loss`` header row and the current (cur_epoch, loss) row to ``path``."""
    self.record(path, ['epoch', 'loss'])
    self.record(path, [self.cur_epoch, loss])

  def record(self, path, row):
    """Append one CSV row to ``path``."""
    with open(path, 'a', newline='') as f:
      csv.writer(f).writerow(row)
    print("Recorded Successfully!")

  def load(self, path, map_location=None):
    """Restore model, optimizer, scheduler, loss, epoch counter and depth from a save() checkpoint.

    ``map_location`` is forwarded to ``torch.load`` (e.g. ``'cpu'`` to open a GPU checkpoint on
    a CPU-only runtime). Resumes at the epoch after the saved one. Checkpoints written before
    ``nb_of_bodies`` was saved keep the current value.
    """
    # SECURITY: the checkpoint contains a pickled loss module, so torch.load unpickles it;
    # only load files you trust.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects the pickled
    # loss module — confirm the torch version in use.
    saved_info = torch.load(path, map_location=map_location)
    self.load_state_dict(saved_info['model_state_dict'])
    self.opt.load_state_dict(saved_info['optimizer_state_dict'])
    self.scheduler.load_state_dict(saved_info['lr_scheduler_state_dict'])
    self.cur_epoch = saved_info['epochs'] + 1
    self.loss_fn = saved_info['loss_fn']
    self.nb_of_bodies = saved_info.get('nb_of_bodies', self.nb_of_bodies)

  def save(self,path):
    """Write a checkpoint (model/optimizer/scheduler state, loss, epoch, depth) to ``path``."""
    torch.save({'epochs': self.cur_epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.opt.state_dict(),
                'loss_fn': self.loss_fn,
                'lr_scheduler_state_dict': self.scheduler.state_dict(),
                'nb_of_bodies': self.nb_of_bodies}, path)
    print("Saved Successfully!")

In [23]:
class Generator(nn.Module):
  """Noise-to-image generator that synthesises inputs for data-free distillation.

  Maps a 1-channel input to 3 channels with two unpadded convs (3x3, then 5x5), so each
  spatial side shrinks by 6 pixels. Owns its L1 loss, SGD optimizer and StepLR scheduler.
  Calling ``train`` with tensors performs one adversarial update step.
  """
  def __init__(self):
    super().__init__()
    # NOTE(review): there is no activation between (or after) the two convs, so the generator
    # is a single linear map and its outputs are not bounded to an image range — confirm intended.
    self.model = nn.Sequential(nn.Conv2d(1,64,3), nn.Conv2d(64, 3, 5))
    self.loss_fn = nn.L1Loss()
    self.opt = torch.optim.SGD(self.parameters(), lr=1e-5)
    self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10, gamma=0.5)
    # Weight of the reconstruction term relative to the adversarial term.
    self.weight_r = 1.

  def forward(self, x: torch.Tensor):
    return self.model(x)

  def train(self, output_generator=True, output_teacher=None, output_student=None):
    """Run one generator update, or switch train/eval mode like ``nn.Module.train``.

    A bool first argument (as passed by ``nn.Module.eval()`` or a parent module) is forwarded
    to ``nn.Module.train`` and ``self`` is returned.

    Args:
      output_generator: images produced by this generator (must carry autograd history).
      output_teacher: teacher outputs on those images.
      output_student: student outputs on those images.

    Returns:
      The combined loss tensor, or ``self`` when used as a mode switch.

    Raises:
      TypeError: a tensor step is requested without both teacher and student outputs.
    """
    if isinstance(output_generator, bool):
      return super().train(output_generator)
    if output_teacher is None or output_student is None:
      raise TypeError("Generator.train() needs output_teacher and output_student for an update step")
    super().train()

    # Adversarial term: minimising -log(L1 + 1) pushes student and teacher outputs apart.
    loss_gen = -torch.log(self.loss_fn(output_student, output_teacher) + 1)
    # Reconstruction term: compare the generated images with the teacher output resized
    # back to the generator's spatial size.
    loss_r = self.loss_fn(output_generator, torchvision.transforms.Resize(output_generator.shape[-2:])(output_teacher))

    loss = loss_gen + self.weight_r * loss_r

    self.opt.zero_grad()
    loss.backward()
    self.opt.step()
    return loss

In [33]:
# Build the generator, the student and the teacher (the teacher is another Student instance).
# NOTE(review): construction order draws from the global torch RNG, and no seed is set in the
# visible cells, so initial weights differ between runs (and change if these lines are reordered).
# NOTE(review): `teacher` is freshly initialised here (no checkpoint is loaded in the visible
# cells) and keeps the default nb_of_bodies = 1, while the student is set to use all 16 bodies —
# confirm this is the intended teacher/student setup.
gene = Generator()
student = Student()
teacher = Student()
student.nb_of_bodies = 16
# Layer-by-layer summary of the 16-body student for a 3x48x48 input, computed on the CPU.
summary(student, (3,48,48), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 48, 48]           1,792
              Head-2           [-1, 64, 48, 48]               0
            Conv2d-3           [-1, 32, 48, 48]          18,464
         LeakyReLU-4           [-1, 32, 48, 48]               0
         ConvBlock-5           [-1, 32, 48, 48]               0
            Conv2d-6           [-1, 32, 48, 48]          27,680
         LeakyReLU-7           [-1, 32, 48, 48]               0
         ConvBlock-8           [-1, 32, 48, 48]               0
            Conv2d-9           [-1, 32, 48, 48]          36,896
        LeakyReLU-10           [-1, 32, 48, 48]               0
        ConvBlock-11           [-1, 32, 48, 48]               0
           Conv2d-12           [-1, 32, 48, 48]          46,112
        LeakyReLU-13           [-1, 32, 48, 48]               0
        ConvBlock-14           [-1, 32,

In [None]:
# Smoke test: push one random single-channel 480x480 noise image through the generator.
# The noise keeps an explicit batch axis (N, C, H, W). An unbatched (C, H, W) tensor makes the
# nn.Upsample inside the student/teacher tail treat it as 1-D data, so only the width is upsampled.
# `noise_image` (not `input`) avoids shadowing the built-in input() for the rest of the notebook.
noise_image = torch.rand((1, 1, 480, 480))
out = gene(noise_image)

In [None]:
# One student imitation step on the smoke-test sample. detach() keeps the student's loss from
# backpropagating into the generator, and makes the teacher output a constant target.
teacher_target = teacher(out).detach()
student.train(out.detach(), teacher_target)

tensor(0.1146, grad_fn=<MeanBackward0>)

In [None]:
class Knowledge_Distillation:
  """Data-free distillation loop alternating student imitation and adversarial generator steps.

  Each iteration has two phases:
  1. ``imitation_steps`` student updates on fresh generator samples; the generator and the
     teacher run without autograd.
  2. One generator update on a fresh sample, with the student's parameters frozen so that only
     the generator is optimised.

  The teacher is frozen for the lifetime of this object. Both learning-rate schedulers are
  stepped once per epoch.

  ``student`` and ``generator`` must expose tensor ``train(...)`` update steps and a
  ``scheduler`` attribute, as ``Student`` and ``Generator`` do.
  """
  def __init__(self, student: nn.Module, teacher: nn.Module, generator: nn.Module, iterations: int = 120, imitation_steps: int = 50, height: int = 500, width: int = 500, batch_size: int = 2):
    self.student = student
    self.teacher = teacher
    self.generator = generator
    self.cur_epoch = 0
    self.imitation_steps = imitation_steps
    self.batch_size = batch_size
    self.height = height
    self.width = width
    self.iterations = iterations

    # The teacher only provides targets; it is never updated.
    self.teacher.requires_grad_(False)

    # The various S_i are as follows S_tail(Body_i...Body_0(S_head))

  def _sample_noise(self) -> torch.Tensor:
    """Draw a (batch_size, 1, height, width) uniform-noise batch on the generator's device."""
    device = next(self.generator.parameters()).device
    return torch.rand((self.batch_size, 1, self.height, self.width), device=device)

  def _student_phase(self) -> None:
    """Run ``imitation_steps`` student updates against teacher outputs on generated images."""
    for k in range(self.imitation_steps):
      print(f'Imitation step {k}/{self.imitation_steps}')
      # Neither the generator nor the teacher is trained here, so skip building their graphs.
      with torch.no_grad():
        generated_images = self.generator(self._sample_noise())
        teacher_output = self.teacher(generated_images)
      loss = self.student.train(generated_images, teacher_output)
      print(loss.item())

  def _generator_phase(self) -> None:
    """Run one generator update with the student's parameters frozen (restored afterwards)."""
    print('\nGenerator training')
    self.generator.requires_grad_()
    self.student.requires_grad_(False)
    try:
      generated_images = self.generator(self._sample_noise())
      teacher_output = self.teacher(generated_images)
      student_output = self.student(generated_images)
      loss = self.generator.train(generated_images, teacher_output, student_output)
    finally:
      # Re-enable student gradients even if the generator step raises.
      self.student.requires_grad_()
    print(loss.item())

  def train(self, epochs: int):
    """Run ``epochs`` epochs of ``iterations`` student/generator rounds each."""
    for epoch in range(epochs):
      self.cur_epoch = epoch
      print(f'\n*********Epoch {self.cur_epoch}/{epochs}*********\n')

      for iteration in range(self.iterations):
        print(f'Iteration {iteration}/{self.iterations}')
        self._student_phase()
        self._generator_phase()

      # Step the StepLR schedulers once per epoch. Stepping them every iteration (120 per
      # epoch with step_size=10) halved the LR ~12 times per epoch, driving it to ~0.
      self.student.scheduler.step()
      self.generator.scheduler.step()


In [None]:
# Distil the teacher into the student using images synthesised by the generator, for 10 epochs.
kd = Knowledge_Distillation(student=student, teacher=teacher, generator=gene)
kd.train(epochs=10)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Imitation step 49/50
0.003474239259958267
Generator training
0.11734476685523987
Iteration 25/120
Imitation step 0/50
0.003500130493193865
Imitation step 1/50
0.0035392974968999624
Imitation step 2/50
0.003648364683613181
Imitation step 3/50
0.0038301863241940737
Imitation step 4/50
0.004091516602784395
Imitation step 5/50
0.004283817484974861
Imitation step 6/50
0.004409704357385635
Imitation step 7/50
0.004378161393105984
Imitation step 8/50
0.004375381860882044
Imitation step 9/50
0.0042854174971580505
Imitation step 10/50
0.004244538489729166
Imitation step 11/50
0.004154119174927473
Imitation step 12/50
0.004132249392569065
Imitation step 13/50
0.004068807233124971
Imitation step 14/50
0.0040396819822490215
Imitation step 15/50
0.003986785188317299
Imitation step 16/50
0.003961618524044752
Imitation step 17/50
0.0038912200834602118
Imitation step 18/50
0.0038797773886471987
Imitation step 19/50
0.003833209630101919
I