In [1]:
import extract

In [2]:
import os
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import cv2
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

from skimage.metrics import peak_signal_noise_ratio,  structural_similarity

In [3]:
class FSRCNN(nn.Module):
    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()

        self.feature_extraction = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),
            nn.PReLU(d)
        )

        self.shrinking = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=1), 
            nn.PReLU(s)
        )

        mapping = []
        for _ in range(m):
            mapping.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])
        self.mapping = nn.Sequential(*mapping)

        self.expanding = nn.Sequential(nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d))
        
        # originally found d instead of s in picture
        self.deconvolution = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,
                                            output_padding=scale_factor-1)
        
        self.out = nn.Sigmoid()

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.feature_extraction:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        
        for m in self.shrinking:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        
        for m in self.mapping:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        
        for m in self.expanding:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
            
        nn.init.normal_(self.deconvolution.weight.data, mean=0.0, std=0.001)
        # nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.shrinking(x)
        x = self.mapping(x)
        x = self.expanding(x)
        x = self.deconvolution(x)
        x = self.out(x) * 255
        return x

In [4]:
video_path = './rutube_hackaton_super_resolution_khabarovsk/train'
train_path = './train_frames'

lr_path = os.path.join(train_path, 'lr')
hr_path = os.path.join(train_path, 'hr')

if not os.path.exists(train_path):
    os.system(f'mkdir -p {train_path}')

if not os.path.exists(lr_path):
    os.system(f'mkdir -p {lr_path}')

if not os.path.exists(hr_path):
    os.system(f'mkdir -p {hr_path}')

In [5]:
files = os.listdir(video_path)
pairs = []
for f in files:
    if f.endswith('_144.mp4'):
        hr_name = f.split('_')[0] + '_480.mp4'
        pairs += [(f, hr_name)]

In [6]:
# n_frames = 5000
# size = int(n_frames // len(pairs))

# save_idx = 0
# for idx in tqdm(range(len(pairs))):
#     pair = pairs[idx]

#     lr = os.path.join(video_path, pair[0])
#     hr = os.path.join(video_path, pair[1])

#     lr_cap = cv2.VideoCapture(lr)
#     hr_cap = cv2.VideoCapture(hr)

#     lr_len = int(lr_cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     hr_len = int(hr_cap.get(cv2.CAP_PROP_FRAME_COUNT))

#     assert lr_len == hr_len

#     frames_idx = [i for i in range(lr_len)]
#     if size:
#         frames_idx = np.random.choice(frames_idx, size=size, replace=False)

#     tmp_idx = 0
#     while True:
#         success_lr, frame_lr = lr_cap.read()
#         success_hr, frame_hr = hr_cap.read()
#         if not success_lr or not success_hr:
#             break
#         if tmp_idx in frames_idx:
#             lr_save_path = os.path.join(lr_path, f'{save_idx}.jpg')
#             hr_save_path = os.path.join(hr_path, f'{save_idx}.jpg')
#             cv2.imwrite(lr_save_path, frame_lr)
#             cv2.imwrite(hr_save_path, frame_hr)
#             save_idx += 1
#         tmp_idx += 1

In [7]:
class SRDataset(Dataset):
    def __init__(self, lr_path, hr_path, transform = None):
        self.lr = [os.path.join(lr_path, f) for f in os.listdir(lr_path)]
        self.hr = [os.path.join(hr_path, f) for f in os.listdir(hr_path)]
        self.lr, self.hr = sorted(self.lr), sorted(self.hr)
        assert len(self.lr) == len(self.hr)
        self.transform = transform

    def __len__(self):
        return len(self.lr)

    def file2np(self, path):
        img = cv2.imread(path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        lr = self.file2np(self.lr[idx])
        hr = self.file2np(self.hr[idx])
        if self.transform is not None: lr, hr = self.transform(lr, hr)
        return lr, hr


In [8]:
class SameTransform(object):
    def __init__(self, mode, crop=None):
        self.np2tensor = transforms.ToTensor()
        self.mode = mode
        self.crop = crop
        self.lr_resize = transforms.Resize((120, 214), antialias = True)

    def __call__(self, lr, hr):
        lr = self.np2tensor(lr)
        hr = self.np2tensor(hr)

        if self.mode == 'train':
            lr, hr = self.same_transform(lr, hr)
            lr = self.lr_resize(lr)

        if self.crop:
            i, j, h, w = transforms.RandomCrop.get_params(lr, self.crop)
            lr = TF.crop(lr, i, j, h, w)
            hr = TF.crop(hr, i, j, h, w)
            
        return lr, hr#np.expand_dims(lr, 0), np.expand_dims(hr, 0)
    
    # после преобразований lr и hr сохраняют пространственное соотношение
    def same_transform(self, image1, image2, p=0.5):
        if random.random() > p:
            image1 = TF.hflip(image1)
            image2 = TF.hflip(image2)

        if random.random() > p:
            image1 = TF.vflip(image1)
            image2 = TF.vflip(image2)

        return image1, image2

In [15]:
class Trainer():
    def __init__(self):
        # устройство на котором идет обучение
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # количество шагов обучения
        self.n_steps = 10000

        # раз во сколько шагов выводить результаты
        self.print_interval = 25
        
        
        # раз во сколько шагов чекпоинт
        self.save_interval = 2500

        self.batch_size = 50
        self.workers = 8

        # инициализация модели
        self.fsrcnn = FSRCNN(scale_factor=4, num_channels=3).to(self.device)

        # конфигурация оптимизатора Adam
        self.optimizer = Adam(
            self.fsrcnn.parameters(),
            0.0001
        )

        # функция потерь MSE
        self.pixel_criterion = nn.MSELoss().to(self.device)

        # разрешение hr изображения в формате (h, w)
        self.size = (480, 856)
        self.gcrop = transforms.CenterCrop([480, 856])

        # # аугментации для обучения и валидации
        train_transform = SameTransform('train')

        # путь где хранятся папки lr и hr с изображениями
        train_prefix = './train_frames'

        # train датасет
        trainset = SRDataset(
            f'{train_prefix}/lr',
            f'{train_prefix}/hr',
            train_transform
        )

        # даталоадер для обучения батчами
        self.trainloader = DataLoader(
            trainset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
            pin_memory=True
        )

        # аугментации для инференса
        self.resize = transforms.Resize(self.size, antialias=None)
        self.np2tensor = transforms.ToTensor()

    def train_step(self, lr, hr):
        g_hr = self.fsrcnn(lr)
        g_hr = self.gcrop.forward(g_hr)
        loss = self.pixel_criterion(g_hr, hr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self):
        self.fsrcnn.train()
        step = 0

        while True:
            if step >= self.n_steps:
                break

            for batch in self.trainloader:
                lr, hr = batch
                lr = lr.to(self.device, non_blocking=True)
                hr = hr.to(self.device, non_blocking=True)

                mse = self.train_step(lr, hr)
                step += 1

                if step % self.print_interval == 0:
                    print(f'STEP={step} MSE={mse:.5f}')
                    
                if step % self.save_interval == 0:
                    torch.save({
                        'epoch': step,
                        'model_state_dict': self.fsrcnn.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'loss': mse,
                        }, "./checkpoint")
        torch.save({
                'epoch': step,
                'model_state_dict': self.fsrcnn.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': mse,
                }, "./checkpoint")

    def frame2tensor(self, img):
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        hr = self.np2tensor(rgb)#self.resize(self.np2tensor(rgb))
        return hr

    def tensor2frame(self, img):
        nparr = (img.detach().cpu().numpy() * 255).astype(np.uint8)
        nparr = np.transpose(nparr, (1, 2, 0))
        bgr = cv2.cvtColor(nparr, cv2.COLOR_RGB2BGR)
        return bgr

    def super_resolution(self, input_video, output_video, test_video = None):
        crop = transforms.CenterCrop(self.size)
        self.fsrcnn.eval()

        cap = cv2.VideoCapture(input_video)
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(
            output_video,
            fourcc,
            fps,
            (self.size[1], self.size[0])
        )
        
        resize_lr = transforms.Resize((120, 214), antialias = True)
        
        if test_video: 
            test_cap = cv2.VideoCapture(test_video)
            psnr = []
            ssim = []

        while True:
            success, frame = cap.read()
            if test_video: 
                t_success, test_frame = test_cap.read()
                success = success and t_success
            
            if not success: break
            
            if test_video:
                psnr.append(peak_signal_noise_ratio(test_frame, frame))
                ssim.append(structural_similarity(test_frame, frame))
            
            tensor = self.frame2tensor(frame).to(self.device).unsqueeze_(0)#lr_crop.forward(self.frame2tensor(frame).to(self.device)).unsqueeze_(0)
            tensor = resize_lr(tensor)
            with torch.no_grad(): 
                output_tensor = self.fsrcnn(tensor)
            output_frame = self.tensor2frame(crop.forward(output_tensor[0]))

            writer.write(output_frame)

        cap.release()
        writer.release()
        
        if test_video: 
            test_cap.release()
            print(f"Average PSNR for output video and ground truth is {np.array(psnr).mean()}")
            print(f"Average SSIM for output video and ground truth is {np.array(ssim).mean()}")

In [10]:
# создаем объект - trainer для запуска процесса обучения и инференса
trainer = Trainer()

In [11]:
# запускаем процесс обучения
trainer.train()

STEP=25 MSE=13369.08691
STEP=50 MSE=4490.10889
STEP=75 MSE=1117.88025
STEP=100 MSE=536.88116
STEP=125 MSE=452.49768
STEP=150 MSE=184.67203
STEP=175 MSE=180.98439
STEP=200 MSE=150.79094
STEP=225 MSE=68.04465
STEP=250 MSE=84.89474
STEP=275 MSE=85.66084
STEP=300 MSE=43.99466
STEP=325 MSE=55.67273
STEP=350 MSE=28.66559
STEP=375 MSE=33.55159
STEP=400 MSE=31.96677
STEP=425 MSE=24.69353
STEP=450 MSE=19.99586
STEP=475 MSE=14.54031
STEP=500 MSE=27.95485
STEP=525 MSE=14.71871
STEP=550 MSE=23.31065
STEP=575 MSE=13.05248
STEP=600 MSE=10.56397
STEP=625 MSE=12.19871
STEP=650 MSE=7.06052
STEP=675 MSE=10.53920
STEP=700 MSE=6.03039
STEP=725 MSE=11.15774
STEP=750 MSE=7.13961
STEP=775 MSE=9.42714
STEP=800 MSE=9.89115
STEP=825 MSE=6.27144
STEP=850 MSE=7.40824
STEP=875 MSE=6.28052
STEP=900 MSE=3.05488
STEP=925 MSE=4.55955
STEP=950 MSE=3.97193
STEP=975 MSE=4.14490
STEP=1000 MSE=6.59838
STEP=1025 MSE=4.46128
STEP=1050 MSE=2.80200
STEP=1075 MSE=4.08943
STEP=1100 MSE=5.47512
STEP=1125 MSE=3.97100
STEP=1150 MSE

In [17]:
lr_video = '/home/owner/Documents/DEV/Python/SuperResolution/rutube_hackaton_super_resolution_khabarovsk/train/1_144.mp4'
out_video = '/home/owner/Documents/DEV/Python/SuperResolution/1_480_new.mp4'
hr_video = '/home/owner/Documents/DEV/Python/SuperResolution/rutube_hackaton_super_resolution_khabarovsk/train/1_480.mp4'


trainer.super_resolution(lr_video, out_video, hr_video)

Average PSNR for output video and ground truth is nan
Average SSIM for output video and ground truth is nan


  print(f"Average PSNR for output video and ground truth is {np.array(psnr).mean()}")
  ret = ret.dtype.type(ret / rcount)
  print(f"Average SSIM for output video and ground truth is {np.array(ssim).mean()}")
