In [16]:
import argparse
import cv2
import numpy as np
import os
from tqdm import tqdm
import torch
from basicsr.archs.ddcolor_arch import DDColor
import torch.nn.functional as F

In [17]:
class ImageColorizationPipeline(object):
    """Colorize images with a pretrained DDColor model.

    The network predicts the Lab ``ab`` (chroma) channels from a grayscale
    rendering of the input resized to ``input_size`` x ``input_size``. The
    prediction is resized back to the original resolution and recombined with
    the original full-resolution ``L`` (luminance) channel.

    Args:
        model_path: Path to a checkpoint loadable by ``torch.load`` whose
            ``'params'`` entry is the model ``state_dict``.
        input_size: Side length of the square network input.
        model_size: ``'tiny'`` selects the ``convnext-t`` encoder; any other
            value selects ``convnext-l``.
        decoder_type: ``'MultiScaleColorDecoder'`` (default) or
            ``'SingleColorDecoder'``.

    Raises:
        ValueError: if ``decoder_type`` is not one of the supported names.
    """

    def __init__(self, model_path, input_size=256, model_size='large',
                 decoder_type='MultiScaleColorDecoder'):
        self.input_size = input_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Any value other than 'tiny' falls back to the large encoder.
        self.encoder_name = 'convnext-t' if model_size == 'tiny' else 'convnext-l'
        self.decoder_type = decoder_type

        # Constructor arguments shared by both decoder variants.
        common_kwargs = dict(
            encoder_name=self.encoder_name,
            input_size=[self.input_size, self.input_size],
            num_output_channels=2,
            last_norm='Spectral',
            do_normalize=False,
        )
        if self.decoder_type == 'MultiScaleColorDecoder':
            self.model = DDColor(
                decoder_name='MultiScaleColorDecoder',
                num_queries=100,
                num_scales=3,
                dec_layers=9,
                **common_kwargs,
            ).to(self.device)
        elif self.decoder_type == 'SingleColorDecoder':
            self.model = DDColor(
                decoder_name='SingleColorDecoder',
                num_queries=256,
                **common_kwargs,
            ).to(self.device)
        else:
            raise ValueError(
                f"Unsupported decoder_type {self.decoder_type!r}; expected "
                f"'MultiScaleColorDecoder' or 'SingleColorDecoder'")

        # Load on CPU first so CPU-only machines can read GPU-saved checkpoints.
        # strict=False: missing/unexpected keys are ignored (shape mismatches
        # still raise).
        self.model.load_state_dict(
            torch.load(model_path, map_location=torch.device('cpu'))['params'],
            strict=False)
        self.model.eval()

    @torch.no_grad()
    def process(self, img):
        """Colorize a single image.

        Args:
            img: 8-bit image in the layout produced by ``cv2.imread`` /
                ``cv2.VideoCapture.read`` (3-channel BGR). 2-D or 1-channel
                grayscale and 4-channel BGRA inputs are converted to BGR first.
                Assumes values in [0, 255] (uint8); TODO confirm for other dtypes.

        Returns:
            uint8 BGR image with the same height and width as ``img``.

        Raises:
            ValueError: if ``img`` is ``None`` (e.g. ``cv2.imread`` failed) or
                has an unsupported channel layout.
        """
        if img is None:
            raise ValueError('img is None; the image could not be read')
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.ndim != 3 or img.shape[2] != 3:
            raise ValueError(f'Expected a BGR image, got array of shape {img.shape}')

        self.height, self.width = img.shape[:2]

        img = (img / 255.0).astype(np.float32)
        # Full-resolution luminance is kept; only chroma comes from the network.
        orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # (h, w, 1)

        # Network input: resize -> Lab -> keep L, zero ab -> RGB, i.e. a
        # 3-channel grayscale rendering at input_size x input_size.
        img_resized = cv2.resize(img, (self.input_size, self.input_size))
        img_l = cv2.cvtColor(img_resized, cv2.COLOR_BGR2Lab)[:, :, :1]
        img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
        img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)

        tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
        # Prediction is at network resolution: (1, 2, input_size, input_size).
        output_ab = self.model(tensor_gray_rgb).cpu()

        # Upsample ab to the original size (F.interpolate default mode, i.e.
        # nearest), prepend the original L channel, and convert back to BGR.
        output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)
        output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
        output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)

        # Clip before the uint8 cast so any out-of-range value saturates
        # instead of wrapping around.
        return np.clip((output_bgr * 255.0).round(), 0, 255).astype(np.uint8)


In [18]:
def extract_frames(video_path):
    """Decode every frame of a video into an in-memory list.

    Args:
        video_path: Path or source string accepted by ``cv2.VideoCapture``.

    Returns:
        list of frames in decode order, exactly as returned by
        ``cv2.VideoCapture.read``.

    Raises:
        OSError: if the video cannot be opened, rather than returning an empty
            list that only fails later (e.g. in ``create_video``).

    Note:
        All frames are held in memory at once; long or high-resolution videos
        may need a streaming approach instead.
    """
    print("Extracting frames from Video...")
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path!r}")
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream (or a read failure)
                break
            frames.append(frame)
        return frames
    finally:
        # Release the capture even if reading raises.
        cap.release()

def process_frames(frames, colorizer):
    """Run ``colorizer.process`` on each frame, preserving order.

    Args:
        frames: Iterable of images accepted by ``colorizer.process``.
        colorizer: Any object exposing ``process(frame)``, such as an
            ``ImageColorizationPipeline``.

    Returns:
        list with one processed result per input frame, in input order.
    """
    print("Processing Frames from video")
    return [colorizer.process(single_frame) for single_frame in frames]

def create_video(processed_frames, output_path, fps=30):
    """Write frames to a video file using the ``mp4v`` codec.

    Args:
        processed_frames: Non-empty sequence of frames shaped
            (height, width, channels); the first frame fixes the output size.
        output_path: Destination file path; its parent directory must exist.
        fps: Output frame rate.

    Raises:
        ValueError: if ``processed_frames`` is empty, or a frame's height/width
            differs from the first frame's (the writer would otherwise drop it
            silently).
        OSError: if OpenCV cannot open the writer (e.g. missing directory or
            unavailable codec), which would otherwise produce no file silently.
    """
    print("Creating the video")
    if len(processed_frames) == 0:
        raise ValueError("processed_frames is empty; nothing to write")
    height, width, _ = processed_frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not video.isOpened():
            raise OSError(f"Could not open video writer for {output_path!r}")
        for index, frame in enumerate(processed_frames):
            if frame.shape[:2] != (height, width):
                raise ValueError(
                    f"Frame {index} has size {frame.shape[:2]}, expected {(height, width)}")
            video.write(frame)
    finally:
        # Always finalize the container, even if a frame check fails.
        video.release()

In [19]:
def main(model_path=None, input_path='DDColor/assets/test_images', output_path='results',
         input_size=512, model_size='large', video_output_path=None, run_inference=False):
    """Build the colorizer and optionally run it on images or a video.

    With the defaults this does what the notebook did before: it prints the
    output path, creates ``output_path``, and loads the model.

    Args:
        model_path: Checkpoint path. Defaults to ``$DDCOLOR_MODEL_PATH`` if set,
            otherwise the original hard-coded local path.
        input_path: A directory of images, or a single video file.
        output_path: Directory for colorized images (created if missing).
        input_size: Network input side length.
        model_size: ``'large'`` or ``'tiny'`` (see ``ImageColorizationPipeline``).
        video_output_path: Directory for ``output_video.mp4``. Defaults to
            ``$DDCOLOR_VIDEO_OUTPUT`` if set, otherwise the original local path.
        run_inference: If True, colorize ``input_path``. The default of False
            keeps the previous load-only behavior.

    Returns:
        The constructed ``ImageColorizationPipeline``.

    Raises:
        ValueError: if ``input_path`` is an empty directory and
            ``run_inference`` is True.
    """
    if model_path is None:
        model_path = os.environ.get(
            'DDCOLOR_MODEL_PATH',
            '/home/raghuram/DDColor/modelscope/damo/cv_ddcolor_image-colorization/pytorch_model.pt')
    if video_output_path is None:
        video_output_path = os.environ.get(
            'DDCOLOR_VIDEO_OUTPUT', '/home/raghuram/ddcolor_dl_final/colorize_output')

    print(f'Output path: {output_path}')

    os.makedirs(output_path, exist_ok=True)

    colorizer = ImageColorizationPipeline(model_path=model_path, input_size=input_size, model_size=model_size)

    if not run_inference:
        return colorizer

    if os.path.isdir(input_path):
        img_names = sorted(os.listdir(input_path))
        if not img_names:
            raise ValueError(f'No files found in {input_path!r}')
        for name in tqdm(img_names):
            img = cv2.imread(os.path.join(input_path, name))
            if img is None:
                # Not a readable image (e.g. a subdirectory or stray file): skip it.
                print(f'Skipping unreadable file: {name}')
                continue
            cv2.imwrite(os.path.join(output_path, name), colorizer.process(img))
    else:
        # Preserve the source frame rate; fall back to 30 if OpenCV reports 0.
        cap = cv2.VideoCapture(input_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        cap.release()
        frames = extract_frames(input_path)
        processed_frames = process_frames(frames, colorizer)
        os.makedirs(video_output_path, exist_ok=True)
        create_video(processed_frames, os.path.join(video_output_path, 'output_video.mp4'), fps=fps)
    return colorizer

In [14]:
# Script entry point. This cell's recorded output shows it also runs inside
# the notebook kernel.
if __name__ == "__main__":
    main()

Output path: results
