In [2]:
import os

import extract

In [None]:
class FSRCNN(nn.Module):
    """Fast Super-Resolution CNN (FSRCNN).

    Pipeline: feature extraction (5x5 conv to `d` maps) -> shrinking (1x1 conv
    to `s` maps) -> `m` mapping layers (3x3 convs on `s` maps) -> expanding
    (1x1 conv back to `d` maps) -> 9x9 transposed convolution with stride
    `scale_factor`. Every convolution except the final deconvolution is
    followed by a PReLU.

    Args:
        scale_factor: stride of the final transposed convolution; the output
            is exactly `scale_factor` times larger than the input in H and W.
        num_channels: number of image channels at input and output.
        d: width of the feature-extraction / expanding layers.
        s: width of the shrunk mapping layers.
        m: number of 3x3 mapping layers.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()

        # feature extraction
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5 // 2),
            nn.PReLU(d),
        )

        # shrinking -> m x mapping -> expanding, created in this exact order
        body = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            body += [nn.Conv2d(s, s, kernel_size=3, padding=3 // 2), nn.PReLU(s)]
        body += [nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]
        self.mid_part = nn.Sequential(*body)

        # deconvolution; with kernel 9 / padding 4, output_padding = scale_factor - 1
        # makes the output size exactly H * scale_factor (and W * scale_factor)
        self.last_part = nn.ConvTranspose2d(
            d, num_channels,
            kernel_size=9, stride=scale_factor, padding=9 // 2,
            output_padding=scale_factor - 1,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise weights in place.

        Hidden Conv2d layers get a zero-mean normal with
        std = sqrt(2 / (out_channels * kh * kw)) and zero bias; the final
        transposed convolution gets std 0.001 and zero bias.
        """
        hidden_layers = list(self.first_part) + list(self.mid_part)
        for layer in hidden_layers:
            if not isinstance(layer, nn.Conv2d):
                continue  # PReLU layers keep their default slopes
            fan = layer.out_channels * layer.weight.data[0][0].numel()
            nn.init.normal_(layer.weight.data, mean=0.0, std=math.sqrt(2 / fan))
            nn.init.zeros_(layer.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        """Upscale `x`; expects batched input (N, num_channels, H, W) since
        PReLU(d) reads channels from dim 1."""
        features = self.first_part(x)
        features = self.mid_part(features)
        return self.last_part(features)

In [None]:
class Trainer():
    """Trains an FSRCNN model on paired LR/HR frames and applies it to videos.

    All hyper-parameters are hard-coded in ``__init__``. At inference time
    (``super_resolution``) every frame is resized to ``self.size`` before it
    enters the network, and the output video is written at ``self.size``, so
    the network must preserve spatial size: it is built with
    ``scale_factor=1``. Frames are RGB (see ``frame2tensor`` /
    ``tensor2frame``), so the network has 3 channels.
    NOTE(review): training assumes ``SameTransform`` yields lr/hr tensors of
    equal spatial size, consistent with the above — confirm.
    """

    def __init__(self):
        # device used for training and inference
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # total number of optimisation steps
        self.n_steps = 10000

        # print the loss every `print_interval` steps
        self.print_interval = 25

        # save a checkpoint every `save_interval` steps
        self.save_interval = 2500

        # directory where checkpoints are written
        self.checkpoint_dir = './checkpoints'

        self.batch_size = 24
        self.workers = 8

        # model: 3 RGB channels, same-size input/output (see class docstring);
        # FSRCNN has no default scale_factor, so it must be passed explicitly
        self.fsrcnn = FSRCNN(scale_factor=1, num_channels=3).to(self.device)

        # Adam optimiser with a fixed learning rate
        self.optimizer = Adam(
            self.fsrcnn.parameters(),
            0.0001
        )

        # pixel-wise MSE loss
        self.pixel_criterion = nn.MSELoss().to(self.device)

        # HR frame resolution as (h, w), and the training crop size
        self.size = (480, 856)
        self.crop = (384, 384)

        # training augmentations
        train_transform = SameTransform(self.size, 'train', crop=self.crop)

        # directory containing the `lr` and `hr` image folders
        train_prefix = './train_frames'

        # training dataset
        trainset = SRDataset(
            f'{train_prefix}/lr',
            f'{train_prefix}/hr',
            train_transform
        )

        # shuffled mini-batch loader for training
        self.trainloader = DataLoader(
            trainset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
            pin_memory=True
        )

        # inference-time transforms
        self.resize = transforms.Resize(self.size, antialias=None)
        self.np2tensor = transforms.ToTensor()

    @property
    def srcnn(self):
        """Backward-compatible alias for ``self.fsrcnn``."""
        return self.fsrcnn

    @srcnn.setter
    def srcnn(self, model):
        self.fsrcnn = model

    def train_step(self, lr, hr):
        """Run one optimisation step on a batch and return the scalar MSE loss."""
        g_hr = self.fsrcnn(lr)
        loss = self.pixel_criterion(g_hr, hr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self):
        """Train for exactly ``self.n_steps`` steps, cycling over the loader.

        Prints the loss every ``print_interval`` steps and saves a checkpoint
        every ``save_interval`` steps.

        Raises:
            RuntimeError: if the loader yields no batches (would loop forever).
        """
        self.fsrcnn.train()
        step = 0

        while step < self.n_steps:
            got_batch = False
            for lr, hr in self.trainloader:
                got_batch = True
                lr = lr.to(self.device, non_blocking=True)
                hr = hr.to(self.device, non_blocking=True)

                mse = self.train_step(lr, hr)
                step += 1

                if step % self.print_interval == 0:
                    print(f'STEP={step} MSE={mse:.5f}')
                if step % self.save_interval == 0:
                    self._save_checkpoint(step)
                # stop mid-epoch once the step budget is spent
                if step >= self.n_steps:
                    break
            if not got_batch:
                raise RuntimeError(
                    'trainloader yielded no batches; check the dataset size '
                    'against batch_size/drop_last'
                )

    def _save_checkpoint(self, step):
        """Save model and optimiser state for `step`; returns the file path."""
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, f'fsrcnn_step{step}.pth')
        torch.save(
            {
                'step': step,
                'model': self.fsrcnn.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            },
            path,
        )
        return path

    def frame2tensor(self, img):
        """Convert an OpenCV BGR uint8 frame to an RGB float tensor in [0, 1],
        shaped (3, h, w) and resized to ``self.size``."""
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        hr = self.resize(self.np2tensor(rgb))
        return hr

    def tensor2frame(self, img):
        """Convert a (3, H, W) RGB tensor to an (H, W, 3) uint8 BGR frame."""
        # The network output is unbounded: clamp before scaling, otherwise the
        # uint8 cast wraps negative / >1 values around. Round instead of
        # truncating so an exact round-trip of ToTensor values is preserved.
        nparr = img.detach().clamp(0, 1).mul(255).round().to(torch.uint8).cpu().numpy()
        nparr = np.ascontiguousarray(np.transpose(nparr, (1, 2, 0)))
        bgr = cv2.cvtColor(nparr, cv2.COLOR_RGB2BGR)
        return bgr

    def super_resolution(self, input_video, output_video):
        """Run the model on every frame of `input_video` and write `output_video`.

        Output is mp4v-encoded at the input's frame rate and at ``self.size``.

        Raises:
            IOError: if the input cannot be opened or the writer cannot be created.
        """
        self.fsrcnn.eval()

        cap = cv2.VideoCapture(input_video)
        if not cap.isOpened():
            cap.release()
            raise IOError(f'cannot open input video: {input_video}')
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(
            output_video,
            fourcc,
            fps,
            (self.size[1], self.size[0])
        )

        try:
            if not writer.isOpened():
                raise IOError(f'cannot open output video: {output_video}')

            with torch.no_grad():
                while True:
                    success, frame = cap.read()
                    if not success:
                        break

                    # add a batch dim: PReLU(d) reads channels from dim 1, so an
                    # unbatched (C, H, W) tensor would be misinterpreted
                    tensor = self.frame2tensor(frame).unsqueeze(0).to(self.device)
                    output_tensor = self.fsrcnn(tensor)
                    writer.write(self.tensor2frame(output_tensor[0]))
        finally:
            cap.release()
            writer.release()