# Single-Image Super-Resolution for satellite imaging
## Deep Learning Course - MVA 2020-2021
#### Quentin Spinat & Thomas Chabal

## Load drive

In [None]:
from google.colab import drive
drive.mount('/content/drive/')

%cd /content/drive/MyDrive/Super_Resolution_DL2020/Python/

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/MyDrive/Super_Resolution_DL2020/Python


## Visualize dataset

In [None]:
from common.constants import DATA_ROOT
from common.dataset import SatelliteDataset
from common.transforms import create_transforms
from utils.visualizer import Visualizer

# Build both splits with the default transforms returned by create_transforms().
train_transforms, test_transforms = create_transforms()
train_dataset = SatelliteDataset(
    DATA_ROOT,
    train_transforms,
    is_training_set=True,
)
test_dataset = SatelliteDataset(
    DATA_ROOT,
    test_transforms,
    is_training_set=False,
)

visualizer = Visualizer()

# Show the first sample of each split.
print('Training example')
first_train_sample = train_dataset[0]
visualizer.visualize_sample(first_train_sample)

print('Test example')
first_test_sample = test_dataset[0]
visualizer.visualize_sample(first_test_sample)

## Classical Computer Vision evaluations

We evaluate the linear and spline interpolation baselines, looking at the L2 loss on both the training set and the test set.

In [None]:
from classical_cv.cv2_resizer import Cv2ResizerEvaluation

# Baseline from classical_cv.cv2_resizer, evaluated on the training split.
cv2_train_evaluation = Cv2ResizerEvaluation()
cv2_train_evaluation.evaluate_dataset(train_dataset)

In [None]:
# Same baseline as the previous cell, this time on the test split.
cv2_test_evaluation = Cv2ResizerEvaluation()
cv2_test_evaluation.evaluate_dataset(test_dataset)

In [None]:
from classical_cv.spline import SplineEvaluation

# Spline interpolation baseline (order 5), evaluated on the training split.
spline_train_evaluation = SplineEvaluation(order=5)
spline_train_evaluation.evaluate_dataset(train_dataset)

In [None]:
# Same order-5 spline baseline, this time on the test split.
spline_test_evaluation = SplineEvaluation(order=5)
spline_test_evaluation.evaluate_dataset(test_dataset)

## AUTOENCODER PART

In [None]:
%cd /content/drive/MyDrive/Super_Resolution_DL2020/Python/

from utils.visualizer import Visualizer
from skimage import transform
from common.dataset import SatelliteDataset
from common.transforms import create_transforms, create_transforms_runet
from autoencoder.AE_preprocess import patch_decomp, patch_recomp
import collections

from autoencoder.AE_model import AE,CDA,mapping
from autoencoder.AE_train import train_AE, train_mapping, fine_tuning

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
use_cuda = torch.cuda.is_available()
if use_cuda :
    device=torch.device("cuda")
    print("using GPU")
else :
    device=torch.device("cpu")
    print("using CPU")




DATA_ROOT = "/content/drive/MyDrive/road_segmentation_ideal/"

train_transforms, test_transforms = create_transforms(128,128)

train_dataset = SatelliteDataset(DATA_ROOT, train_transforms, is_training_set=True)
test_dataset = SatelliteDataset(DATA_ROOT, test_transforms, is_training_set=False)

### Test of the patch decomposition and recomposition

In [None]:
# Round-trip sanity check: upscale one training image by 2 along its last two
# axes, split it into patches, then rebuild the image from those patches.
img = train_dataset[0]["image"].numpy()
img_c, img_h, img_w = img.shape
img_big = transform.resize(img, (img_c, 2 * img_h, 2 * img_w))

patches = patch_decomp(img_big)
print(patches.shape)

img_recomp = patch_recomp(patches, img_shape=img_big.shape)
print(img_recomp.shape)

# Axis 0 is moved to the end before display (assumes channels-first arrays,
# since imshow expects the channel axis last -- TODO confirm).
axes_for_display = (1, 2, 0)

# Original image.
plt.figure()
plt.imshow(np.transpose(img, axes_for_display))

# Upscaled image.
plt.figure()
plt.imshow(np.transpose(img_big, axes_for_display))

# First patch of the upscaled image.
plt.figure()
plt.imshow(np.transpose(patches[0], axes_for_display))

# Image rebuilt from the patches.
plt.figure()
plt.imshow(np.transpose(img_recomp, axes_for_display))

### Loading the dataset into torch

In [None]:
# Build the patch-level training tensors. For every training image, the input
# is resized to the label's shape, both are cut into patches, and the SAME
# random subset of patch indices is kept for the input and for the label.
# (It's ugly, but there is no other way.)

N_img = 800            # number of training images to draw patches from
N_patch_by_fig = 1000  # maximum number of patches kept per image

# Do not index past the end of the dataset when it holds fewer than N_img images.
n_images = min(N_img, len(train_dataset))

patch_img = []
patch_label = []

for i in range(n_images):
  print(i, end=' ')
  sample = train_dataset[i]
  img = sample["image"].numpy()
  label = sample["label"].numpy()

  img_patches = patch_decomp(transform.resize(img, label.shape))
  label_patches = patch_decomp(label)

  # Sampling without replacement raises if more items are requested than are
  # available, so never ask for more patches than this image provides.
  n_available = img_patches.shape[0]
  n_keep = min(N_patch_by_fig, n_available)
  random_elements = np.random.choice(n_available, n_keep, replace=False)

  # Subsample first, then convert: only the kept patches become tensors.
  patch_img.append(torch.FloatTensor(img_patches[random_elements]))
  patch_label.append(torch.FloatTensor(label_patches[random_elements]))

patch_img = torch.cat(patch_img, dim=0).to(device)
patch_label = torch.cat(patch_label, dim=0).to(device)

print()
print(patch_img.shape)

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 27

In [None]:
bs = 256
num_epochs = 10


def pretrain_autoencoder(patches, checkpoint_path):
  """Train a fresh AE to reconstruct `patches` and save its weights.

  Args:
    patches: tensor of patches, used both as input and as target.
    checkpoint_path: file the trained state dict is written to.

  Returns:
    The trained AE model (on `device`).
  """
  model = AE().to(device)
  # Use local names here: rebinding the notebook-level `train_dataset` would
  # break a re-run of the patch-loading cell, which indexes it as the
  # SatelliteDataset.
  patch_dataset = torch.utils.data.TensorDataset(patches, patches)
  # The patch tensors are concatenated image by image, so shuffle to avoid
  # batches that only contain patches from one or two images.
  patch_loader = torch.utils.data.DataLoader(patch_dataset, batch_size=bs, shuffle=True)
  for epoch in range(1, num_epochs + 1):
    train_AE(model, patch_loader, epoch, log_interval=1000)
  torch.save(model.state_dict(), checkpoint_path)
  return model


#step1
print("LR autoencoder training")
LR_AE = pretrain_autoencoder(patch_img, "AE_LR_L2.pth")

#step2
print("HR autoencoder training")
HR_AE = pretrain_autoencoder(patch_label, "AE_HR_L2.pth")


LR autoencoder training
HR autoencoder training


In [None]:
#step3
# Train the mapping model; LR_AE and HR_AE (trained in the previous cell) are
# passed to train_mapping together with it.
print("mapping training")
model_mapping = mapping().to(device)
# Do not rebind `train_dataset` here: the patch-loading cell indexes it as the
# SatelliteDataset, so shadowing it with a TensorDataset breaks a re-run.
patch_pair_dataset = torch.utils.data.TensorDataset(patch_img, patch_label)
# Patches are stored image by image; shuffle so batches mix several images.
# `train_loader` keeps its name because the fine-tuning cell (step4) reads it.
train_loader = torch.utils.data.DataLoader(patch_pair_dataset, batch_size=bs, shuffle=True)
for epoch in range(1, num_epochs + 1):
  train_mapping(LR_AE, model_mapping, HR_AE, train_loader, epoch, log_interval=1000)
torch.save(model_mapping.state_dict(), "AE_mapping_L2.pth")


mapping training


In [None]:
#step4
model_CDA = CDA().to(device)
model_CDA.load_state_dict(collections.OrderedDict([('enc.weight',LR_AE.enc.weight),
                                           ('enc.bias',LR_AE.enc.bias),
                                           ('map.weight',model_mapping.map.weight),
                                           ('map.bias',model_mapping.map.bias),
                                           ('dec.weight',HR_AE.dec.weight),
                                           ('dec.bias',HR_AE.dec.bias)
                                                ]))

for epoch in range(1,num_epochs+1):
  fine_tuning(model_CDA,train_loader,epoch,lr=0.001,log_interval=1000)
torch.save(model_CDA.state_dict(), "AE_CDA_L2.pth")



### Test of the trained network

In [None]:
import matplotlib.pyplot as plt
import torch
from skimage import transform

from autoencoder.AE_model import CDA
from autoencoder.visu import visu
# Imported here so that this cell runs on a fresh kernel, without relying on
# names left behind by the training cells.
from common.dataset import SatelliteDataset
from common.transforms import create_transforms

DATA_ROOT = "/content/drive/MyDrive/road_segmentation_ideal/"

use_cuda = torch.cuda.is_available()
if use_cuda :
    device=torch.device("cuda")
    print("using GPU")
else :
    device=torch.device("cpu")
    print("using CPU")

# Reload the fine-tuned weights. map_location lets a checkpoint saved from a
# GPU runtime be loaded on a CPU-only runtime as well.
model_CDA = CDA().to(device)
model_CDA.load_state_dict(torch.load("AE_CDA_L2.pth", map_location=device))
model_CDA.eval()

# Rebuild the datasets (the training cells above rebind `train_dataset`).
train_transforms, test_transforms = create_transforms(128,128)

train_dataset = SatelliteDataset(DATA_ROOT, train_transforms, is_training_set=True)
test_dataset = SatelliteDataset(DATA_ROOT, test_transforms, is_training_set=False)

visu(model_CDA,test_dataset,8)

In [None]:
from autoencoder.evaluate import test_model

# Quantitative evaluation of the CDA on the test split; the value returned by
# test_model is kept in `res`.
res = test_model(model_CDA, test_dataset)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))


Mean PSNR of 25.446 on test set with std of 1.704
Mean SSIM of 0.92243 on test set with std of 0.03976
Mean MSE of 0.0018783 on test set with std of 0.0005505
Mean VGG-Perceptual of 5.28254 on test set with std of 0.74915


## GAN SUPER-RESOLUTION

In [None]:
# Fetch the idealo image-super-resolution sources and install them from the
# checkout (the `ISR` package is imported further down in this notebook).
!git clone https://github.com/idealo/image-super-resolution.git
%cd /content/image-super-resolution/
!python setup.py install
# Go back to the project's Python directory for the following cells.
%cd /content/drive/MyDrive/Super_Resolution_DL2020/Python/

### Visualize the results of the GANs

In [None]:
from rrdn_gans.visualizer import GANsVisualizer

# Qualitative check of the GAN model on a batch of 8 images.
gans_batch_size = 8
visualizer_gans = GANsVisualizer()
visualizer_gans.visualize_gans(batch_size=gans_batch_size)

### Evaluate quantitatively the RRDN

In [None]:
from rrdn_gans.evaluate import RRDNEvaluation

# Quantitative evaluation of the RRDN, configured with a patch size of 128.
rrdn_patch_size = 128
evaluation_rrdn = RRDNEvaluation(patch_size=rrdn_patch_size)
evaluation_rrdn.evaluate()

## RUNET

### Training

In [None]:
from runet.main import train_runet

# Training hyper-parameters, gathered in one place for readability.
runet_training_config = dict(
    img_size=128,
    train_bs=32,
    test_bs=1,
    lr=0.001,
)
train_runet(**runet_training_config)

### Visualize the results of RUNet

In [None]:
from runet.visualize import RUNetVisualizer

# RUNet checkpoint to visualise.
checkpoint_unet = "checkpoints/perceptual_loss_RUNET_var_blur.pth"

# Qualitative check of the RUNet on a batch of 8 images.
runet_batch_size = 8
visualizer_runet = RUNetVisualizer()
visualizer_runet.visualize_runet(checkpoint_unet, batch_size=runet_batch_size)

### Evaluate quantitatively the RUNet

In [None]:
from runet.evaluate import RUNetEvaluation

# Same checkpoint as the one used in the visualization cell.
checkpoint_unet = "checkpoints/perceptual_loss_RUNET_var_blur.pth"

# Quantitative evaluation of the RUNet loaded from that checkpoint.
evaluator_runet = RUNetEvaluation()
evaluator_runet.evaluate(checkpoint_unet)

## General visualization

In [None]:
import torch
from ISR.models import RRDN

# CDA is imported here (and `device` is defined below) so that this cell does
# not depend on names left behind by the autoencoder cells.
from autoencoder.AE_model import CDA
from common.constants import DATA_ROOT
from common.dataset import SatelliteDataset
from common.transforms import create_transforms
from runet.runet import RUNet
from global_visualization import visualize_all_models

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transforms, test_transforms = create_transforms(128, 128)
train_dataset = SatelliteDataset(DATA_ROOT, train_transforms, is_training_set=True)
test_dataset = SatelliteDataset(DATA_ROOT, test_transforms, is_training_set=False)

model_RRDN = RRDN(weights="gans", patch_size=64)

# map_location lets checkpoints saved from a GPU runtime be loaded on a
# CPU-only runtime as well.
model_CDA = CDA().to(device)
model_CDA.load_state_dict(torch.load("AE_CDA_L2.pth", map_location=device))
model_CDA.eval()

model_RUNET = RUNet().to(device)
model_RUNET.load_state_dict(
    torch.load("checkpoints/perceptual_loss_RUNET_var_blur.pth", map_location=device)
)
model_RUNET.eval()

# Compare the three models on the same test image.
visualize_all_models(model_CDA, model_RUNET, model_RRDN, test_dataset, num_images=1, image_id=0)