In [1]:
import extract

In [1]:
import os
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import cv2
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

In [2]:
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network.

    The network is built from three stages:

    * ``first_part`` - 5x5 feature-extraction convolution with ``d`` filters;
    * ``mid_part``   - 1x1 shrinking to ``s`` channels, ``m`` 3x3 mapping
      convolutions, then 1x1 expanding back to ``d`` channels;
    * ``last_part``  - 9x9 transposed convolution with stride ``scale_factor``.

    Every convolution is followed by a per-channel PReLU.

    Args:
        scale_factor: integer upscaling factor (stride of the deconvolution).
        num_channels: number of input and output image channels.
        d: number of feature-extraction / expanding channels.
        s: number of channels inside the shrunken mapping stage.
        m: number of 3x3 mapping layers.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super().__init__()

        # Feature extraction.
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
            nn.PReLU(d),
        )

        # Shrinking -> mapping -> expanding, collected in creation order.
        stages = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            stages += [nn.Conv2d(s, s, kernel_size=3, padding=1), nn.PReLU(s)]
        stages += [nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]
        self.mid_part = nn.Sequential(*stages)

        # Deconvolution. It consumes the d-channel output of the expanding
        # layer; padding and output_padding are chosen so the spatial size is
        # multiplied by exactly scale_factor.
        self.last_part = nn.ConvTranspose2d(
            d,
            num_channels,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise all convolutions and the final deconvolution.

        Convolution weights are drawn from a zero-mean normal distribution
        with std ``sqrt(2 / (out_channels * kernel_area))``; the deconvolution
        uses std 0.001. All biases are set to zero.
        """
        for part in (self.first_part, self.mid_part):
            for layer in part:
                if not isinstance(layer, nn.Conv2d):
                    continue
                kernel_area = layer.weight.data[0][0].numel()
                std = math.sqrt(2 / (layer.out_channels * kernel_area))
                nn.init.normal_(layer.weight.data, mean=0.0, std=std)
                nn.init.zeros_(layer.bias.data)

        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        """Run the three stages in sequence and return the upscaled tensor."""
        return self.last_part(self.mid_part(self.first_part(x)))

In [3]:
# video_path = './rutube_hackaton_super_resolution_khabarovsk/train'
# train_path = './train_frames'

# lr_path = os.path.join(train_path, 'lr')
# hr_path = os.path.join(train_path, 'hr')

# if not os.path.exists(train_path):
#     os.system(f'mkdir -p {train_path}')

# if not os.path.exists(lr_path):
#     os.system(f'mkdir -p {lr_path}')

# if not os.path.exists(hr_path):
#     os.system(f'mkdir -p {hr_path}')

In [4]:
# files = os.listdir(video_path)
# pairs = []
# for f in files:
#     if f.endswith('_144.mp4'):
#         hr_name = f.split('_')[0] + '_480.mp4'
#         pairs += [(f, hr_name)]

In [5]:
# n_frames = 5000
# size = int(n_frames // len(pairs))

# save_idx = 0
# for idx in tqdm(range(len(pairs))):
#     pair = pairs[idx]

#     lr = os.path.join(video_path, pair[0])
#     hr = os.path.join(video_path, pair[1])

#     lr_cap = cv2.VideoCapture(lr)
#     hr_cap = cv2.VideoCapture(hr)

#     lr_len = int(lr_cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     hr_len = int(hr_cap.get(cv2.CAP_PROP_FRAME_COUNT))

#     assert lr_len == hr_len

#     frames_idx = [i for i in range(lr_len)]
#     if size:
#         frames_idx = np.random.choice(frames_idx, size=size, replace=False)

#     tmp_idx = 0
#     while True:
#         success_lr, frame_lr = lr_cap.read()
#         success_hr, frame_hr = hr_cap.read()
#         if not success_lr or not success_hr:
#             break
#         if tmp_idx in frames_idx:
#             lr_save_path = os.path.join(lr_path, f'{save_idx}.jpg')
#             hr_save_path = os.path.join(hr_path, f'{save_idx}.jpg')
#             cv2.imwrite(lr_save_path, frame_lr)
#             cv2.imwrite(hr_save_path, frame_hr)
#             save_idx += 1
#         tmp_idx += 1

In [6]:
class SRDataset(Dataset):
    """Dataset of paired low-resolution / high-resolution images.

    Regular files in ``lr_path`` and ``hr_path`` are listed and sorted, and
    the i-th LR file is paired with the i-th HR file. Pairing is purely
    positional, so both directories are expected to use the same file names.

    Args:
        lr_path: directory containing the low-resolution images.
        hr_path: directory containing the high-resolution images.
        transform: optional callable ``transform(lr, hr) -> (lr, hr)`` applied
            to every loaded pair.
    """

    def __init__(self, lr_path, hr_path, transform = None):
        self.lr = self._list_files(lr_path)
        self.hr = self._list_files(hr_path)
        assert len(self.lr) == len(self.hr), (
            f'LR/HR image count mismatch: {len(self.lr)} in {lr_path!r} '
            f'vs {len(self.hr)} in {hr_path!r}'
        )
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular files inside ``directory``."""
        # os.listdir also returns sub-directories (for example Jupyter's
        # ``.ipynb_checkpoints``); they can never be decoded as images, so
        # they are skipped here instead of failing later in file2np.
        paths = (os.path.join(directory, name) for name in os.listdir(directory))
        return sorted(path for path in paths if os.path.isfile(path))

    def __len__(self):
        return len(self.lr)

    def file2np(self, path):
        """Read the image at ``path`` and return it as an RGB numpy array.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError(f'Failed to read image: {path}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` pair at ``idx``, transformed when a transform is set."""
        lr = self.file2np(self.lr[idx])
        hr = self.file2np(self.hr[idx])
        if self.transform is not None:
            lr, hr = self.transform(lr, hr)
        return lr, hr


In [7]:
class SameTransform(object):
    """Joint transform for an (LR, HR) image pair.

    Both images are converted to tensors, the LR image is center-cropped to
    ``lr_size`` and, depending on the settings, the pair is randomly flipped
    and/or randomly cropped so that LR and HR stay spatially aligned.

    Args:
        mode: random flips are applied only when this equals ``'train'``.
        crop: optional ``(h, w)`` random-crop size measured in LR pixels; the
            HR patch is taken at the same location, scaled by the HR/LR size
            ratio. ``None`` disables random cropping.
        lr_size: ``(h, w)`` the LR image is center-cropped to.
    """

    def __init__(self, mode, crop=None, lr_size=(120, 214)):
        self.np2tensor = transforms.ToTensor()
        self.mode = mode
        self.crop = crop
        # NOTE(review): a center crop keeps only the middle of the LR frame
        # while the HR frame is used in full - confirm that both cover the
        # same field of view, otherwise the pair is not pixel-aligned.
        self.lr_crop = transforms.CenterCrop(list(lr_size))

    def __call__(self, lr, hr):
        """Transform one pair and return ``(lr, hr)`` as tensors."""
        lr = self.lr_crop(self.np2tensor(lr))
        hr = self.np2tensor(hr)

        if self.mode == 'train':
            lr, hr = self.same_transform(lr, hr)

        if self.crop:
            # The crop window is sampled in LR coordinates. HR is an upscaled
            # view of the same scene, so its window must be scaled by the
            # HR/LR ratio; reusing the LR window would cut an unrelated,
            # equally sized region from the HR image.
            scale_h = max(hr.shape[-2] // lr.shape[-2], 1)
            scale_w = max(hr.shape[-1] // lr.shape[-1], 1)
            i, j, h, w = transforms.RandomCrop.get_params(lr, self.crop)
            lr = TF.crop(lr, i, j, h, w)
            hr = TF.crop(hr, i * scale_h, j * scale_w, h * scale_h, w * scale_w)

        return lr, hr

    # After these transforms LR and HR keep their spatial correspondence.
    def same_transform(self, image1, image2, p=0.5):
        """Apply the same random flips to both images.

        Each of the horizontal and vertical flips is applied independently
        with probability ``p``.
        """
        if random.random() < p:
            image1 = TF.hflip(image1)
            image2 = TF.hflip(image2)

        if random.random() < p:
            image1 = TF.vflip(image1)
            image2 = TF.vflip(image2)

        return image1, image2

In [56]:
class Trainer():
    """Trains an x4 FSRCNN on paired frames and applies it to videos.

    Args:
        train_prefix: directory that contains the ``lr`` and ``hr`` image folders.
        n_steps: number of optimisation steps performed by :meth:`train`.
        batch_size: training batch size.
        workers: number of DataLoader worker processes.
        learning_rate: Adam learning rate.
    """

    def __init__(self, train_prefix='./train_frames', n_steps=10000,
                 batch_size=24, workers=8, learning_rate=0.0001):
        # Training device: first CUDA GPU when available, otherwise CPU.
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Number of training steps.
        self.n_steps = n_steps

        # Print the loss every this many steps.
        self.print_interval = 25

        # Intended checkpoint interval.
        # NOTE(review): nothing in this class reads save_interval yet, so no
        # checkpoints are written during training.
        self.save_interval = 2500

        self.batch_size = batch_size
        self.workers = workers

        # Model.
        self.fsrcnn = FSRCNN(scale_factor=4, num_channels=3).to(self.device)

        # Adam optimiser.
        self.optimizer = Adam(
            self.fsrcnn.parameters(),
            learning_rate
        )

        # Pixel-wise MSE loss.
        self.pixel_criterion = nn.MSELoss().to(self.device)

        # HR frame resolution as (h, w).
        self.size = (480, 856)

        # Paired augmentations used for training.
        train_transform = SameTransform('train')

        # Training dataset built from the lr/ and hr/ folders.
        trainset = SRDataset(
            f'{train_prefix}/lr',
            f'{train_prefix}/hr',
            train_transform
        )

        # Batched, shuffled loader; incomplete last batches are dropped.
        self.trainloader = DataLoader(
            trainset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
            pin_memory=True
        )

        # Helpers for inference.
        self.resize = transforms.Resize(self.size, antialias=None)
        self.np2tensor = transforms.ToTensor()

    def train_step(self, lr, hr):
        """Run one optimisation step on a batch and return the MSE loss as a float."""
        g_hr = self.fsrcnn(lr)
        loss = self.pixel_criterion(g_hr, hr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self):
        """Train for exactly ``self.n_steps`` steps, cycling over the loader.

        Raises:
            RuntimeError: if the loader yields no batches (it would otherwise
                loop forever without making progress).
        """
        self.fsrcnn.train()
        step = 0

        while step < self.n_steps:
            batches_seen = 0

            for lr, hr in self.trainloader:
                batches_seen += 1
                lr = lr.to(self.device, non_blocking=True)
                hr = hr.to(self.device, non_blocking=True)

                mse = self.train_step(lr, hr)
                step += 1

                if step % self.print_interval == 0:
                    print(f'STEP={step} MSE={mse:.5f}')

                # Stop inside the epoch so the step budget is not overshot.
                if step >= self.n_steps:
                    break

            if batches_seen == 0:
                raise RuntimeError(
                    'trainloader produced no batches; the dataset is probably '
                    'smaller than batch_size while drop_last=True'
                )

    def frame2tensor(self, img):
        """Convert a BGR frame (as read by OpenCV) to an RGB tensor via ToTensor."""
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.np2tensor(rgb)

    def tensor2frame(self, img):
        """Convert a (C, H, W) RGB tensor in [0, 1] to an (H, W, C) uint8 BGR frame."""
        # The network output is unbounded; without clamping, values outside
        # [0, 1] would wrap around when cast to uint8.
        clamped = img.detach().clamp(0.0, 1.0)
        nparr = (clamped.cpu().numpy() * 255.0).round().astype(np.uint8)
        nparr = np.transpose(nparr, (1, 2, 0))
        bgr = cv2.cvtColor(nparr, cv2.COLOR_RGB2BGR)
        return bgr

    def super_resolution(self, input_video, output_video):
        """Upscale ``input_video`` frame by frame and write it to ``output_video``.

        Each frame is center-cropped to 120x214, passed through the network
        and written with the input's frame rate at resolution ``self.size``.

        Raises:
            OSError: if the input video cannot be opened.
        """
        lr_crop = transforms.CenterCrop([120, 214])
        self.fsrcnn.eval()

        cap = cv2.VideoCapture(input_video)
        writer = None
        try:
            if not cap.isOpened():
                raise OSError(f'Cannot open input video: {input_video}')

            fps = cap.get(cv2.CAP_PROP_FPS)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            # VideoWriter takes the frame size as (width, height).
            writer = cv2.VideoWriter(
                output_video,
                fourcc,
                fps,
                (self.size[1], self.size[0])
            )

            while True:
                success, frame = cap.read()
                if not success:
                    break
                tensor = lr_crop(self.frame2tensor(frame).to(self.device)).unsqueeze_(0)
                with torch.no_grad():
                    output_tensor = self.fsrcnn(tensor)
                output_frame = self.tensor2frame(output_tensor[0])

                writer.write(output_frame)
        finally:
            # Release the video handles even when processing fails midway.
            cap.release()
            if writer is not None:
                writer.release()

In [57]:
# Create the Trainer object that drives both training and inference.
trainer = Trainer()

In [58]:
# Start the training loop.
trainer.train()

STEP=25 MSE=0.06122
STEP=50 MSE=0.02678
STEP=75 MSE=0.02713
STEP=100 MSE=0.02587
STEP=125 MSE=0.02690
STEP=150 MSE=0.02259
STEP=175 MSE=0.01992
STEP=200 MSE=0.01971
STEP=225 MSE=0.01843
STEP=250 MSE=0.02251
STEP=275 MSE=0.02237
STEP=300 MSE=0.02124
STEP=325 MSE=0.01713
STEP=350 MSE=0.01708
STEP=375 MSE=0.01798
STEP=400 MSE=0.01974
STEP=425 MSE=0.02305
STEP=450 MSE=0.01715
STEP=475 MSE=0.01840
STEP=500 MSE=0.02116
STEP=525 MSE=0.01597
STEP=550 MSE=0.01687
STEP=575 MSE=0.02168
STEP=600 MSE=0.01583
STEP=625 MSE=0.01905
STEP=650 MSE=0.01729
STEP=675 MSE=0.01501
STEP=700 MSE=0.01734
STEP=725 MSE=0.01690
STEP=750 MSE=0.01767
STEP=775 MSE=0.01634
STEP=800 MSE=0.01846
STEP=825 MSE=0.01797
STEP=850 MSE=0.01828
STEP=875 MSE=0.02281
STEP=900 MSE=0.01686
STEP=925 MSE=0.01620
STEP=950 MSE=0.01598
STEP=975 MSE=0.01805
STEP=1000 MSE=0.01505
STEP=1025 MSE=0.01913
STEP=1050 MSE=0.01652
STEP=1075 MSE=0.01607
STEP=1100 MSE=0.01693
STEP=1125 MSE=0.01712
STEP=1150 MSE=0.01650
STEP=1175 MSE=0.01508
STEP=120

In [59]:
# Project-relative paths, consistent with the './train_frames' prefix used for
# training; the notebook is expected to run from the project root.
video_dir = os.path.join('.', 'rutube_hackaton_super_resolution_khabarovsk', 'train')
lr_video = os.path.join(video_dir, '1_144.mp4')
# Output file written by the model (not an original 480p video).
hr_video = os.path.join(video_dir, '1_480_new.mp4')

trainer.super_resolution(lr_video, hr_video)

In [None]:
def super_resolution(input_video, output_video):
    """Upscale ``input_video`` with the global ``trainer`` using batched inference.

    Frames are read in chunks of up to ``trainer.batch_size``, center-cropped
    to 120x214, passed through ``trainer.fsrcnn`` in a single forward pass and
    written to ``output_video`` in their original order, using the input's
    frame rate and the resolution ``trainer.size``.

    Raises:
        OSError: if the input video cannot be opened.
    """
    lr_crop = transforms.CenterCrop([120, 214])
    trainer.fsrcnn.eval()

    cap = cv2.VideoCapture(input_video)
    writer = None

    def _flush(frames):
        # Stack the pending frames into one (N, C, H, W) batch. The previous
        # ``.repeat(batch_size)`` call passed fewer repeat dims than the
        # tensor has and therefore raised a RuntimeError.
        batch = torch.stack([lr_crop(trainer.frame2tensor(f)) for f in frames])
        with torch.no_grad():
            outputs = trainer.fsrcnn(batch.to(trainer.device))
        # Keep values in [0, 1] so the uint8 conversion cannot wrap around.
        outputs = outputs.clamp(0.0, 1.0)
        for output in outputs:
            writer.write(trainer.tensor2frame(output))

    try:
        if not cap.isOpened():
            raise OSError(f'Cannot open input video: {input_video}')

        fps = cap.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # VideoWriter takes the frame size as (width, height).
        writer = cv2.VideoWriter(
            output_video,
            fourcc,
            fps,
            (trainer.size[1], trainer.size[0])
        )

        pending = []
        while True:
            success, frame = cap.read()
            if not success:
                break
            pending.append(frame)
            if len(pending) >= trainer.batch_size:
                _flush(pending)
                pending = []

        # Process the final, possibly incomplete, chunk.
        if pending:
            _flush(pending)
    finally:
        # Release the video handles even when processing fails midway.
        cap.release()
        if writer is not None:
            writer.release()