In [None]:
# smooth.py
import argparse
import os

import numpy as np
from PIL import Image
from skimage.io import imsave


def lin_interp(n, p1, p2):
    """Return ``n`` evenly spaced linear interpolants strictly between ``p1`` and ``p2``.

    Sample ``i`` (0-based) uses weight ``a = (i + 1) / (n + 1)`` and equals
    ``(1 - a) * p1 + a * p2``; neither endpoint itself is included.

    Args:
        n: number of intermediate samples to produce (0 gives an empty result).
        p1, p2: arrays of the same shape (any shape; previously the trailing
            dimensions were hard-coded to ``(128, 3)``).

    Returns:
        float64 array of shape ``(n, *p1.shape)``.
    """
    p1 = np.asarray(p1, dtype=np.float64)
    p2 = np.asarray(p2, dtype=np.float64)
    # weights shaped (n, 1, 1, ...) so each one broadcasts over a whole sample
    a = np.arange(1, n + 1, dtype=np.float64) / (n + 1)
    a = a.reshape((n,) + (1,) * p1.ndim)
    return (1 - a) * p1 + a * p2


def smooth(in_img, ws):
    """Blend the tile seams of a patch-wise reconstruction by linear interpolation.

    ``in_img`` is the path of a side-by-side image: 768 rows, the left 1280
    columns holding the (padded) input frame and the right 1280 columns the
    reconstruction assembled from a 6x10 grid of 128x128 tiles. Across every
    horizontal and then every vertical tile border, the ``ws`` pixels straddling
    the seam (``ws // 2`` on each side) are replaced by a linear ramp between
    row/column ``128 - ws // 2`` of the first tile and row/column ``ws // 2`` of
    the next tile.

    The output -- left frame and smoothed reconstruction, both with 24 rows cut
    from top and bottom, side by side -- is written to ``<name>_s<ws><ext>``.

    Args:
        in_img: path of the input image.
        ws: even window size in [2, 254].

    Raises:
        ValueError: if ``ws`` is odd or outside [2, 254] (odd sizes give
            mismatched halves; larger ones index outside a 128-pixel tile).
    """
    if ws % 2 != 0 or not 2 <= ws <= 254:
        raise ValueError(f"ws must be an even integer in [2, 254], got {ws}")

    _name, _ext = os.path.splitext(in_img)
    out_img = f"{_name}_s{ws}{_ext}"

    # the context manager closes the file; convert("RGB") drops an alpha channel
    # and expands palette images so the 3-channel reshape below holds
    with Image.open(in_img) as im:
        in_img = np.asarray(im.convert("RGB"), dtype=np.float64) / 255.0
    orig_img = in_img[24:-24, :1280, :]  # left image, remove borders
    in_img = in_img[:, 1280:, :]  # right image

    # (768, 1280, 3) -> (6, 10, 128, 128, 3): tile row, tile col, y, x, channel
    patches = np.reshape(in_img, (6, 128, 10, 128, 3))
    patches = np.transpose(patches, (0, 2, 1, 3, 4))

    h = ws // 2

    # horizontal seams between tile rows i and i + 1
    for i in range(5):
        p1 = patches[i, :, 128 - h, :, :]
        p2 = patches[i + 1, :, h, :, :]

        x = lin_interp(ws, p1, p2)
        patches[i, :, 128 - h :, :, :] = np.transpose(x[:h, :, :, :], (1, 0, 2, 3))
        patches[i + 1, :, :h, :, :] = np.transpose(x[h:, :, :, :], (1, 0, 2, 3))

    # vertical seams between tile columns j and j + 1
    for j in range(9):
        p3 = patches[:, j, :, 128 - h, :]
        p4 = patches[:, j + 1, :, h, :]

        x = lin_interp(ws, p3, p4)
        patches[:, j, :, 128 - h :, :] = np.transpose(x[:h, :, :, :], (1, 2, 0, 3))
        patches[:, j + 1, :, :h, :] = np.transpose(x[h:, :, :, :], (1, 2, 0, 3))

    out = np.transpose(patches, (0, 2, 1, 3, 4))
    out = np.reshape(out, (768, 1280, 3))
    out = out[24:-24, :, :]

    out = np.concatenate((orig_img, out), axis=1)
    # explicit, rounded 8-bit conversion instead of imsave's implicit lossy cast
    out = np.round(np.clip(out, 0.0, 1.0) * 255.0).astype(np.uint8)
    imsave(out_img, out)

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--in_img", type=str, required=True)
#     parser.add_argument("--window_size", type=int, required=True)
#     args = parser.parse_args()

#     # make sure an even size is used
#     args.window_size += args.window_size % 2

#     smooth(args.in_img, args.window_size)



In [None]:
!pip install colored
!pip install numpy

Collecting colored
  Downloading colored-1.4.3.tar.gz (29 kB)
Building wheels for collected packages: colored
  Building wheel for colored (setup.py) ... [?25l[?25hdone
  Created wheel for colored: filename=colored-1.4.3-py3-none-any.whl size=14341 sha256=dbd35bf2255dfaff88e6fa4b2208fe78566323f1dd476120efe180cd48460761
  Stored in directory: /root/.cache/pip/wheels/4a/f6/00/835e81851bc345428a253721c8bdad0062721dfb861bc6e752
Successfully built colored
Installing collected packages: colored
Successfully installed colored-1.4.3


In [None]:
# logger.py
"""
Inspired by https://github.com/SebiSebi/friendlylog
"""

import logging
import sys
from copy import copy
from typing import Union

from colored import attr, fg

DEBUG = "debug"
INFO = "info"
WARNING = "warning"
ERROR = "error"
CRITICAL = "critical"

LOG_LEVELS = {
    DEBUG: logging.DEBUG,
    INFO: logging.INFO,
    WARNING: logging.WARNING,
    ERROR: logging.ERROR,
    CRITICAL: logging.CRITICAL,
}


class _Formatter(logging.Formatter):
    """``logging.Formatter`` that prefixes each message with ``LEVEL: `` and can colour it.

    Args:
        colorize: when true, wrap the processed message in escape codes from the
            ``colored`` package (a per-level foreground colour; WARNING and above
            additionally get ``attr(1)``/``attr(21)``).
        *args, **kwargs: forwarded unchanged to ``logging.Formatter``.
    """

    def __init__(self, colorize=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.colorize = colorize

    @staticmethod
    def _process(msg, loglevel, colorize):
        """Return ``msg`` prefixed with its upper-case level name, optionally coloured.

        Raises:
            RuntimeError: if ``loglevel`` is not a key of ``LOG_LEVELS``.
        """
        level = str(loglevel).lower()
        if level not in LOG_LEVELS:
            raise RuntimeError(
                f"{level} should be one of {LOG_LEVELS}."
            )  # pragma: no cover

        text = f"{level.upper()}: {str(msg)}"
        if not colorize:
            return text

        # low-severity levels: foreground colour only
        plain_colours = {DEBUG: 5, INFO: 4}
        if level in plain_colours:
            return f"{fg(plain_colours[level])}{text}{attr(0)}"

        # WARNING / ERROR / CRITICAL: colour plus the extra attr(1) ... attr(21) pair
        emphasised_colours = {WARNING: 214, ERROR: 202, CRITICAL: 196}
        return f"{fg(emphasised_colours[level])}{attr(1)}{text}{attr(21)}{attr(0)}"

    def format(self, record):
        """Format a shallow copy of ``record`` so the caller's record keeps its raw msg."""
        patched = copy(record)
        patched.msg = _Formatter._process(patched.msg, patched.levelname, self.colorize)
        return super().format(patched)


class Logger:
    """Named logger with its own formatter and one handler per output stream.

    Wraps a dedicated ``logging.Logger`` named ``_logger-<name>`` that does not
    propagate to ancestor loggers, and exposes its ``debug``/``info``/
    ``warning``/``error``/``critical`` methods directly on the instance.

    Args:
        name: suffix of the underlying logger's name.
        colorize: forwarded to ``_Formatter`` to enable coloured output.
        stream: stream used by the default handler.
        level: initial level, either a key of ``LOG_LEVELS`` (any case) or a
            numeric ``logging`` level.
    """

    def __init__(self, name="default", colorize=False, stream=sys.stdout, level=DEBUG):
        self.name = name

        # get the logger object; keep it hidden as there's no need to directly access it
        self.__logger = logging.getLogger(f"_logger-{name}")
        self.__logger.propagate = False
        # setLevel handles names (case-insensitively) and ints; calling
        # level.lower() here made Logger(level=logging.INFO) crash
        self.setLevel(level)

        # use the custom formatter
        self.__formatter = _Formatter(
            colorize=colorize,
            fmt="[%(process)d][%(asctime)s.%(msecs)03d @ %(funcName)s] %(message)s",
            datefmt="%y-%m-%d %H:%M:%S",
        )

        # install default handler; clear_handlers() drops handlers left on the
        # shared logging.Logger by a previous Logger with the same name and
        # (re)initialises the stream -> handler map
        self.clear_handlers()
        self.__main_handler = self.add_handler(stream)

        # install logging functions
        self.debug = self.__logger.debug
        self.info = self.__logger.info
        self.warning = self.__logger.warning
        self.error = self.__logger.error
        self.critical = self.__logger.critical

    def log_function(self):
        """Return a decorator that logs, at INFO, each call and return of a function.

        Usage: ``@logger.log_function()``. The wrapped function keeps its
        ``__name__``, ``__doc__`` etc. thanks to ``functools.wraps``.
        """
        from functools import wraps

        def wrapper(func):
            @wraps(func)
            def func_wrapper(*args, **kwargs):
                self.__logger.info(
                    f"calling <{func.__name__}>\n\t  args: {args}\n\tkwargs: {kwargs}"
                )
                out = func(*args, **kwargs)
                self.__logger.info(f"exiting <{func.__name__}>")
                return out

            return func_wrapper

        return wrapper

    def setLevel(self, level: Union[str, int]) -> None:
        """Set the level from a ``LOG_LEVELS`` name (any case) or a numeric level.

        Raises:
            ValueError: if ``level`` is a string that is not a key of ``LOG_LEVELS``.
        """
        if isinstance(level, int):
            self.__logger.setLevel(level)
        else:
            if level.lower() not in LOG_LEVELS:
                raise ValueError(f"level should be one of {LOG_LEVELS}")
            self.__logger.setLevel(LOG_LEVELS[level.lower()])

    def add_handler(self, stream) -> logging.StreamHandler:
        """Attach a formatted ``StreamHandler`` writing to ``stream`` and return it.

        If ``stream`` already has a handler, that handler is returned unchanged.
        Registering a second one would duplicate every message and orphan the
        first handler, so ``remove_handler`` could never detach it.
        """
        existing = self.__stream_to_handler.get(stream)
        if existing is not None:
            return existing

        handler = logging.StreamHandler(stream)
        handler.setFormatter(self.__formatter)
        self.__logger.addHandler(handler)
        self.__stream_to_handler[stream] = handler
        return handler

    def remove_handler(self, stream) -> bool:
        """Detach the handler for ``stream``; return whether one was registered."""
        if stream in self.__stream_to_handler:
            self.__logger.removeHandler(self.__stream_to_handler[stream])
            self.__stream_to_handler.pop(stream)
            return True
        return False

    def clear_handlers(self) -> None:
        """Remove every handler from the underlying logger and forget all streams."""
        self.__logger.handlers = []
        self.__stream_to_handler = {}

    def get_handlers(self) -> list:
        """Return the underlying logger's handler list (the live list, not a copy)."""
        return self.__logger.handlers

    # Don't use these unless you know what you are doing

    @property
    def inner_logger(self):
        return self.__logger

    @property
    def inner_stream_handler(self):
        return self.__main_handler

    @property
    def inner_formatter(self):
        return self.__formatter


In [None]:
#!/bin/bash
# resize.sh -- optional preprocessing step, does not need to be run first.
# Usage: resize.sh <image>
# Resizes <image> to exactly 1280x720 ("!" ignores the aspect ratio) with
# ImageMagick's `convert` and writes <same dir>/<same basename>.bmp.
# (The shebang must be the very first line to take effect.)

if [ "$#" -ne 1 ]; then
    echo "usage: $0 <image>" >&2
    exit 1
fi

file_name=$(basename "$1"); file_name="${file_name%.*}"
dir_name=$(dirname "$1")

# quote paths so names with spaces or glob characters are passed through intact
convert "$1" -resize '1280x720!' "${dir_name}/${file_name}.bmp"

SyntaxError: ignored

In [None]:
# utils.py
import struct
import numpy as np
from torchvision.utils import save_image


def save_imgs(imgs, to_size, name) -> None:
    """Clamp a batch of images to [0, 1], reshape it and save it with ``save_image``.

    Args:
        imgs: tensor whose first dimension is the batch size; each item must
            hold ``C * H * W`` values.
        to_size: per-image target shape ``(C, H, W)``.
        name: output file path.
    """
    imgs = imgs.clamp(0, 1)
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work too
    imgs = imgs.reshape(imgs.size(0), *to_size)
    save_image(imgs, name)


def save_encoded(enc: np.ndarray, fname: str) -> None:
    """Write ``enc``, flattened in C order, to ``fname`` as raw native-endian float64.

    Produces the same bytes as ``struct.pack(f"{enc.size}d", *enc.ravel())`` but
    without unpacking every element into a Python argument tuple, which is slow
    and memory-hungry for large encodings.
    """
    flat = np.asarray(enc, dtype=np.float64).reshape(-1)
    with open(fname, "wb") as fp:
        fp.write(flat.tobytes())


In [None]:
# data_loader.py
from pathlib import Path
from typing import Tuple

import numpy as np
import torch as T
from PIL import Image
from torch.utils.data import Dataset


class ImageFolder720p(Dataset):
    """
    Image shape is (720, 1280, 3) --> (768, 1280, 3) --> 6x10 128x128 patches

    Each item is ``(img, patches, path)``:
    - ``img``: float tensor (3, 768, 1280) in [0, 1]; the 720 rows are
      edge-padded with 24 rows at the top and bottom.
    - ``patches``: (3, 6, 10, 128, 128) view of ``img`` where
      ``patches[:, i, j]`` is the tile at tile-row ``i``, tile-column ``j``.
    - ``path``: the source file path as a string.
    """

    def __init__(self, root: str):
        # regular files only: sub-directories such as ``.ipynb_checkpoints``
        # would otherwise be handed to PIL and crash __getitem__
        self.files = sorted(p for p in Path(root).iterdir() if p.is_file())

    @staticmethod
    def _to_patches(img: T.Tensor) -> T.Tensor:
        """Split a (3, 768, 1280) tensor into a (3, 6, 10, 128, 128) grid of tiles."""
        return img.reshape(3, 6, 128, 10, 128).permute(0, 1, 3, 2, 4)

    def __getitem__(self, index: int) -> Tuple[T.Tensor, T.Tensor, str]:
        path = str(self.files[index % len(self.files)])
        # close the file handle; convert("RGB") normalises RGBA/grayscale/palette
        with Image.open(path) as im:
            img = np.array(im.convert("RGB"))

        if img.shape != (720, 1280, 3):
            raise ValueError(f"{path}: expected a 1280x720 image, got shape {img.shape}")

        pad = ((24, 24), (0, 0), (0, 0))
        img = np.pad(img, pad, mode="edge") / 255.0

        img = np.transpose(img, (2, 0, 1))
        img = T.from_numpy(img).float()

        return img, self._to_patches(img), path

    def __len__(self):
        return len(self.files)


In [None]:
# namespace.py
import json


class Namespace:
    """Attribute-style access to (possibly nested) keyword arguments.

    Dict values become nested ``Namespace`` objects and lists/tuples become
    lists with their elements converted the same way, so a config loaded from
    YAML/JSON can be read as ``cfg.section.key``. ``str()`` renders the
    contents as indented JSON.
    """

    def __init__(self, **kwargs):
        self.__dict__.update({k: self.__elt(v) for k, v in kwargs.items()})

    def __elt(self, xs):
        """Recursively convert dicts to Namespaces and lists/tuples to lists."""
        if isinstance(xs, dict):
            return Namespace(**xs)

        if isinstance(xs, (list, tuple)):
            return [self.__elt(x) for x in xs]

        return xs

    def __str__(self):
        return json.dumps(self.__get_nested(), indent=4)

    def __repr__(self):
        return str(self)

    def __len__(self):
        return len(self.__dict__)

    def __plain(self, v):
        """Convert one attribute value into something ``json.dumps`` can serialise."""
        # nested namespace; "<self>" guards against a direct self-reference
        if isinstance(v, Namespace):
            return "<self>" if v is self else v.__get_nested()

        # lists can hold nested namespaces (e.g. a YAML list of mappings);
        # recurse, otherwise json.dumps raises TypeError on them
        if isinstance(v, (list, tuple)):
            return [self.__plain(x) for x in v]

        # non-primitive type, call its str method
        if hasattr(v, "__dict__"):
            return str(v)

        # primitives
        return v

    def __get_nested(self) -> dict:
        return {k: self.__plain(v) for k, v in self.__dict__.items()}

    def is_empty(self):
        return len(self) == 0

    def to_dict(self) -> dict:
        """Return the contents as plain (JSON-compatible) nested dicts/lists."""
        return self.__get_nested()

    def to_file(self, fname) -> None:
        """Write the JSON rendering of this namespace to ``fname``."""
        with open(fname, "wt") as fp:
            fp.write(str(self))

In [None]:





# train

import os
import yaml
import argparse
from pathlib import Path

import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# from utils import save_imgs

# from namespace import Namespace
# from logger import Logger

from models.cae_32x32x32_zero_pad_bin import CAE

# module-level logger used by train() below; colourised console output
logger = Logger(__name__, colorize=True)

# NOTE(review): train() does not use this global -- it places the model and data
# according to cfg.device. Presumably kept for interactive notebook use; confirm.
device = T.device("cuda" if T.cuda.is_available() else "cpu")

def train(cfg):
    """Train the CAE patch by patch on ``cfg.dataset_path``.

    Everything is written under ``<parent of cwd>/experiments/<cfg.exp_name>/``:
    ``train_config.json``, TensorBoard logs in ``logs/``, side-by-side
    reconstructions in ``out/`` every ``cfg.save_every`` batches, checkpoints in
    ``checkpoint/`` every ``cfg.epoch_every`` epochs, and ``model_final.pth``.

    Config keys read: exp_name, device, dataset_path, batch_size, shuffle,
    num_workers, learning_rate, start_epoch, num_epochs, batch_every,
    save_every, epoch_every and, optionally, resume/checkpoint.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is unavailable.
        ValueError: if ``cfg.resume`` is set without a ``cfg.checkpoint``.
    """
    dev = T.device(cfg.device)
    if dev.type == "cuda" and not T.cuda.is_available():
        raise RuntimeError(f"cfg.device is {cfg.device!r} but CUDA is not available")

    root_dir = os.path.dirname(os.getcwd())

    logger.info("training: experiment %s" % (cfg.exp_name))

    # make dir-tree
    exp_dir = os.path.join(root_dir, "experiments", cfg.exp_name)
    for d in ["out", "checkpoint", "logs"]:
        os.makedirs(os.path.join(exp_dir, d), exist_ok=True)

    cfg.to_file(os.path.join(exp_dir, "train_config.json"))

    tb_writer = SummaryWriter(os.path.join(exp_dir, "logs"))
    logger.info("started tensorboard writer")

    try:
        model = CAE()
        # honour the resume/checkpoint keys of the config (ignored previously)
        if getattr(cfg, "resume", False):
            if not getattr(cfg, "checkpoint", None):
                raise ValueError("cfg.resume is set but cfg.checkpoint is empty")
            model.load_state_dict(T.load(cfg.checkpoint, map_location="cpu"))
            logger.info(f"resumed from {cfg.checkpoint}")
        model.train()
        model.to(dev)
        logger.info(f"loaded model on {cfg.device}")

        dataloader = DataLoader(
            dataset=ImageFolder720p(cfg.dataset_path),
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
        )
        logger.info(f"loaded dataset from {cfg.dataset_path}")

        optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=1e-5)
        loss_criterion = nn.MSELoss()

        avg_loss, epoch_avg = 0.0, 0.0
        ts = 0

        # EPOCHS
        for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):
            # BATCHES
            for batch_idx, data in enumerate(dataloader, start=1):
                img, patches, _ = data
                patches = patches.to(dev)

                # one optimizer step per tile position; the reported loss is the
                # mean over the 6 x 10 = 60 positions
                avg_loss_per_image = 0.0
                for i in range(6):
                    for j in range(10):
                        optimizer.zero_grad()

                        x = patches[:, :, i, j, :, :]
                        y = model(x)
                        loss = loss_criterion(y, x)

                        avg_loss_per_image += (1 / 60) * loss.item()

                        loss.backward()
                        optimizer.step()

                avg_loss += avg_loss_per_image
                epoch_avg += avg_loss_per_image

                if batch_idx % cfg.batch_every == 0:
                    tb_writer.add_scalar("train/avg_loss", avg_loss / cfg.batch_every, ts)

                    for name, param in model.named_parameters():
                        tb_writer.add_histogram(name, param, ts)

                    logger.debug(
                        "[%3d/%3d][%5d/%5d] avg_loss: %.8f"
                        % (
                            epoch_idx,
                            cfg.num_epochs,
                            batch_idx,
                            len(dataloader),
                            avg_loss / cfg.batch_every,
                        )
                    )

                    avg_loss = 0.0
                    ts += 1
                # -- end batch every

                if batch_idx % cfg.save_every == 0:
                    # reconstruct the first image of the batch for inspection;
                    # inference only, so no autograd graph is built
                    out = T.zeros(6, 10, 3, 128, 128)
                    with T.no_grad():
                        for i in range(6):
                            for j in range(10):
                                x = patches[0, :, i, j, :, :].unsqueeze(0)
                                out[i, j] = model(x)[0].cpu()

                    # (row, col, C, y, x) -> (C, row * 128 + y, col * 128 + x)
                    out = out.permute(2, 0, 3, 1, 4).reshape(3, 768, 1280)

                    y = T.cat((img[0], out), dim=2).unsqueeze(0)
                    save_imgs(
                        imgs=y,
                        to_size=(3, 768, 2 * 1280),
                        name=os.path.join(exp_dir, f"out/{epoch_idx}_{batch_idx}.png"),
                    )
                # -- end save every
            # -- end batches

            if epoch_idx % cfg.epoch_every == 0:
                epoch_avg /= len(dataloader) * cfg.epoch_every

                # log the epoch average itself (previously the leftover per-batch
                # accumulator avg_loss / batch_every was logged by mistake)
                tb_writer.add_scalar(
                    "train/epoch_avg_loss",
                    epoch_avg,
                    epoch_idx // cfg.epoch_every,
                )

                logger.info("Epoch avg = %.8f" % epoch_avg)
                epoch_avg = 0.0

                T.save(
                    model.state_dict(),
                    os.path.join(exp_dir, f"checkpoint/model_{epoch_idx}.pth"),
                )
            # -- end epoch every
        # -- end epoch

        # save final model
        T.save(model.state_dict(), os.path.join(exp_dir, "model_final.pth"))
    finally:
        # close the writer even if training fails part-way
        tb_writer.close()


if __name__ == "__main__":
    # parse_known_args ignores the extra "-f <kernel.json>" argument Jupyter
    # passes, so the same code runs both inside a notebook and from the CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/train.yaml")
    args, _ = parser.parse_known_args()

    with open(args.config, "rt") as fp:
        cfg = Namespace(**yaml.safe_load(fp))
    print(cfg)

    train(cfg)


{
    "num_epochs": 1,
    "batch_size": 16,
    "learning_rate": 0.0001,
    "resume": false,
    "checkpoint": null,
    "start_epoch": 1,
    "exp_name": "training",
    "batch_every": 1,
    "save_every": 10,
    "epoch_every": 1,
    "shuffle": true,
    "dataset_path": "drive/MyDrive/yt_small_720p",
    "num_workers": 2,
    "device": "cuda"
}
[73][21-11-22 13:12:21.483 @ train] INFO: training: experiment training
[73][21-11-22 13:12:21.504 @ train] INFO: started tensorboard writer
[73][21-11-22 13:12:21.539 @ train] INFO: loaded model on cuda
[73][21-11-22 13:12:21.589 @ train] INFO: loaded dataset from drive/MyDrive/yt_small_720p
[73][21-11-22 13:12:56.464 @ train] DEBUG: [  1/  1][    1/  143] avg_loss: 0.15425662
[73][21-11-22 13:13:28.992 @ train] DEBUG: [  1/  1][    2/  143] avg_loss: 0.02007326
[73][21-11-22 13:14:01.441 @ train] DEBUG: [  1/  1][    3/  143] avg_loss: 0.01710654
[73][21-11-22 13:14:33.798 @ train] DEBUG: [  1/  1][    4/  143] avg_loss: 0.01456667
[73][2

In [None]:
# test.py
import os
import yaml
import argparse
from pathlib import Path

import numpy as np
import torch as T
import torch.nn as nn
#from torch.utils.data import DataLoader

#from data_loader import ImageFolder720p
#from utils import save_imgs

# from namespace import Namespace
# from logger import Logger

from models.cae_32x32x32_zero_pad_bin import CAE

# experiments are stored at <parent of cwd>/experiments; os.path.join avoids the
# "//experiments" produced by string concatenation when the parent is "/"
ROOT_EXP_DIR = os.path.join(os.path.dirname(os.getcwd()), "experiments")

logger = Logger(__name__, colorize=True)


def test(cfg):
    """Evaluate a trained CAE checkpoint on every image in ``cfg.dataset_path``.

    For each image, all 6x10 tiles are reconstructed. The mean per-tile MSE is
    logged at DEBUG level. The original and the reconstruction are saved side by
    side to ``<ROOT_EXP_DIR>/<cfg.exp_name>/out/test_<n>.png``.

    Raises:
        ValueError: if ``cfg.checkpoint`` is empty or ``cfg.device`` is neither
            "cpu" nor an available "cuda".
    """
    # local import: the module-level DataLoader import in this cell is commented out
    from torch.utils.data import DataLoader

    if cfg.checkpoint in (None, ""):
        raise ValueError("cfg.checkpoint must point to a saved model")
    if not (cfg.device == "cpu" or (cfg.device == "cuda" and T.cuda.is_available())):
        raise ValueError(f"unusable device {cfg.device!r}")
    dev = T.device(cfg.device)

    exp_dir = os.path.join(ROOT_EXP_DIR, cfg.exp_name)
    os.makedirs(os.path.join(exp_dir, "out"), exist_ok=True)
    cfg.to_file(os.path.join(exp_dir, "test_config.json"))
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    # map_location="cpu" lets a CUDA-trained checkpoint load on a CPU-only host;
    # the model is moved to the requested device afterwards
    model.load_state_dict(T.load(cfg.checkpoint, map_location="cpu"))
    model.eval()
    model.to(dev)
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(
        dataset=ImageFolder720p(cfg.dataset_path),
        batch_size=1,
        shuffle=cfg.shuffle,
        num_workers=getattr(cfg, "num_workers", 0),
    )
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    for batch_idx, (img, patches, _) in enumerate(dataloader, start=1):
        patches = patches.to(dev)

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0.0

        # inference only: no autograd graph needed
        with T.no_grad():
            for i in range(6):
                for j in range(10):
                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    out[i, j] = y[0].cpu()  # batch_size is 1

                    loss = loss_criterion(y, x)
                    avg_loss += (1 / 60) * loss.item()

        logger.debug("[%5d/%5d] avg_loss: %f", batch_idx, len(dataloader), avg_loss)

        # save output: (row, col, C, y, x) -> (C, row * 128 + y, col * 128 + x)
        out = out.permute(2, 0, 3, 1, 4).reshape(3, 768, 1280)

        y = T.cat((img[0], out), dim=2)
        save_imgs(
            imgs=y.unsqueeze(0),
            to_size=(3, 768, 2 * 1280),
            name=os.path.join(exp_dir, f"out/test_{batch_idx}.png"),
        )


if __name__ == "__main__":
    # parse_known_args ignores the extra "-f <kernel.json>" argument Jupyter
    # passes, so the same code runs both inside a notebook and from the CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/test.yaml")
    args, _ = parser.parse_known_args()

    with open(args.config, "rt") as fp:
        cfg = Namespace(**yaml.safe_load(fp))
    print(cfg)

    test(cfg)


{
    "checkpoint": "../experiments/training/model_final.pth",
    "exp_name": "testing",
    "batch_every": 100,
    "shuffle": false,
    "dataset_path": "testing",
    "num_workers": 1,
    "device": "cuda"
}
[73][21-11-22 15:11:45.776 @ test] INFO: [exp dir=//experiments/testing]
[73][21-11-22 15:11:45.836 @ test] INFO: [model=../experiments/training/model_final.pth] on cuda
[73][21-11-22 15:11:45.842 @ test] INFO: [dataset=testing]
[73][21-11-22 15:11:50.505 @ test] DEBUG: [    1/    1] avg_loss: 0.013746


In [None]:
#SSIM
import numpy as np
import cv2
from PIL import Image 
from scipy.signal import convolve2d
 
def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
    """Build a normalised 2-D Gaussian kernel like MATLAB's ``fspecial('gaussian', shape, sigma)``.

    Entries below ``eps * max`` are zeroed. The kernel is then scaled to sum to
    1, unless it sums to 0, in which case it is returned unscaled.
    """
    rows, cols = shape
    half_r = (rows - 1.) / 2.
    half_c = (cols - 1.) / 2.
    yy, xx = np.ogrid[-half_r:half_r + 1, -half_c:half_c + 1]

    kernel = np.exp(-(xx * xx + yy * yy) / (2. * sigma * sigma))
    cutoff = np.finfo(kernel.dtype).eps * kernel.max()
    kernel[kernel < cutoff] = 0

    total = kernel.sum()
    return kernel / total if total != 0 else kernel
 
def filter2(x, kernel, mode='same'):
    """MATLAB ``filter2`` equivalent: 2-D correlation of ``x`` with ``kernel``.

    Correlation is computed as convolution with the kernel rotated 180 degrees.
    """
    flipped = np.flip(kernel, axis=(0, 1))
    return convolve2d(x, flipped, mode=mode)
 
def compute_ssim(im1, im2, k1=0.01, k2=0.03, win_size=11, L=255):
    """Mean SSIM between two single-channel images.

    Uses a ``win_size`` x ``win_size`` Gaussian window (sigma 1.5) and keeps only
    the 'valid' part of every filtering, as MATLAB's ``filter2(..., 'valid')`` does.

    Args:
        im1, im2: 2-D arrays of identical shape.
        k1, k2: stabilising constants; C1 = (k1*L)**2 and C2 = (k2*L)**2.
        win_size: side length of the Gaussian window.
        L: dynamic range of the pixel values (255 for 8-bit images).

    Returns:
        The mean of the SSIM map.

    Raises:
        ValueError: on a shape mismatch, non-2-D input, or images smaller than
            the window (whose 'valid' SSIM map would be empty).
    """
    im1 = np.asarray(im1)
    im2 = np.asarray(im2)
    if not im1.shape == im2.shape:
        raise ValueError("Input Images must have the same dimensions")
    if im1.ndim != 2:
        raise ValueError("Please input the images with 1 channel")
    if min(im1.shape) < win_size:
        raise ValueError(f"Images must be at least {win_size}x{win_size} pixels")

    C1 = (k1 * L) ** 2
    C2 = (k2 * L) ** 2
    window = matlab_style_gauss2D(shape=(win_size, win_size), sigma=1.5)
    window = window / np.sum(window)

    # always compute in float64: squaring uint16/int images in their own dtype overflows
    im1 = im1.astype(np.float64)
    im2 = im2.astype(np.float64)

    mu1 = filter2(im1, window, 'valid')
    mu2 = filter2(im2, window, 'valid')
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = filter2(im1 * im1, window, 'valid') - mu1_sq
    sigma2_sq = filter2(im2 * im2, window, 'valid') - mu2_sq
    sigma12 = filter2(im1 * im2, window, 'valid') - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    return float(np.mean(ssim_map))
 
 
if __name__ == "__main__":
    grays = []
    for path in ("1.png", "2.png"):
        bgr = cv2.imread(path)
        # cv2.imread returns None (no exception) for missing/unreadable files
        if bgr is None:
            raise FileNotFoundError(f"could not read image {path!r}")
        grays.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
    im1, im2 = grays

    print(compute_ssim(im1, im2))

In [None]:
#PSNR
import cv2
import numpy as np
import math
 

 
def psnr(img1, img2):
    """Peak signal-to-noise ratio, in dB, between two 8-bit images.

    Pixels are scaled to [0, 1] by dividing by 255, so the peak value is 1.
    Returns 100 when the images are numerically identical (MSE < 1e-10).

    Raises:
        ValueError: if the shapes differ; NumPy broadcasting would otherwise
            silently compare mismatched arrays (e.g. colour vs. single channel).
    """
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    if img1.shape != img2.shape:
        raise ValueError(
            f"images must have the same shape, got {img1.shape} and {img2.shape}"
        )
    mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
    if mse < 1.0e-10:
        return 100
    PIXEL_MAX = 1
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))



if __name__ == "__main__":
    im1 = cv2.imread("1.png")
    im2 = cv2.imread("2.png")
    # cv2.imread returns None (no exception) for missing/unreadable files
    if im1 is None or im2 is None:
        raise FileNotFoundError("could not read 1.png and/or 2.png")
    print(psnr(im1, im2))

In [None]:
#MS-SSIM
#ssim.py
import numpy as np
import scipy.signal
import cv2


def ssim_index_new(img1,img2,K,win):
    """Mean SSIM and mean contrast-structure (cs) term of two images at one scale.

    This is the per-scale helper used by ``msssim``.

    Args:
        img1, img2: 2-D single-channel images of equal shape (cast to float32).
        K: pair ``(K1, K2)``; C1 = (K1*255)**2 and C2 = (K2*255)**2. When either
            constant is 0, the divisions are guarded element-wise (see below).
        win: 2-D weighting window; normalised here to sum to 1.

    Returns:
        ``(mssim, mcs)``: means of the SSIM map and the cs map over the 'valid'
        region of the window.
    """
    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)

    C1 = (K[0] * 255) ** 2
    C2 = (K[1] * 255) ** 2
    win = win / np.sum(win)

    def filt(a):
        return scipy.signal.convolve2d(a, win, mode='valid')

    mu1 = filt(img1)
    mu2 = filt(img2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = filt(img1 * img1) - mu1_sq
    sigma2_sq = filt(img2 * img2) - mu2_sq
    sigma12 = filt(img1 * img2) - mu1_mu2

    numerator1 = 2 * mu1_mu2 + C1
    numerator2 = 2 * sigma12 + C2
    denominator1 = mu1_sq + mu2_sq + C1
    denominator2 = sigma1_sq + sigma2_sq + C2

    if C1 > 0 and C2 > 0:
        ssim_map = (numerator1 * numerator2) / (denominator1 * denominator2)
        cs_map = numerator2 / denominator2
    else:
        # With a zero constant the denominators can vanish. Divide only where
        # they are valid and keep 1 elsewhere. The two masks below are disjoint.
        ssim_map = np.ones(mu1.shape)
        both = denominator1 * denominator2 > 0
        # parenthesised product: the old loop computed n1*n2/d1*d2
        ssim_map[both] = (numerator1[both] * numerator2[both]) / (
            denominator1[both] * denominator2[both]
        )
        only_first = (denominator1 != 0) & (denominator2 == 0)
        ssim_map[only_first] = numerator1[only_first] / denominator1[only_first]

        cs_map = np.ones(mu1.shape)
        positive = denominator2 > 0
        cs_map[positive] = numerator2[positive] / denominator2[positive]

    mssim = np.mean(ssim_map)
    mcs = np.mean(cs_map)

    return mssim, mcs


def msssim(img1,img2):
    """Multi-scale SSIM of two single-channel images.

    Uses five scales with the weights listed below. After each scale, both
    images are filtered with a 2x2 box filter (reflected borders) and decimated
    by 2. The result is ``prod(mcs[s] ** w[s] for the first four scales) *
    mssim[4] ** w[4]``.

    Raises:
        ValueError: if the images differ in shape or a scale becomes smaller
            than the 11x11 window (its 'valid' SSIM map would be empty, giving NaN).
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same dimensions")

    K = [0.01, 0.03]
    gauss = cv2.getGaussianKernel(11, 1.5)
    win = np.multiply(gauss, gauss.T)  # outer product -> 11x11 Gaussian window
    level = 5
    weight = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]

    downsample_filter = np.ones((2, 2)) / 4
    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)

    mssim_array = []
    mcs_array = []

    for _ in range(level):
        if min(img1.shape) < win.shape[0]:
            raise ValueError(
                f"images are too small for {level}-scale MS-SSIM with a "
                f"{win.shape[0]}x{win.shape[1]} window"
            )
        mssim, mcs = ssim_index_new(img1, img2, K, win)
        mssim_array.append(mssim)
        mcs_array.append(mcs)
        filtered_im1 = cv2.filter2D(img1, -1, downsample_filter, anchor=(0, 0), borderType=cv2.BORDER_REFLECT)
        filtered_im2 = cv2.filter2D(img2, -1, downsample_filter, anchor=(0, 0), borderType=cv2.BORDER_REFLECT)
        img1 = filtered_im1[::2, ::2]
        img2 = filtered_im2[::2, ::2]

    # cs terms of the first level-1 scales, full SSIM at the coarsest scale
    overall_mssim = np.prod(np.power(mcs_array[:level - 1], weight[:level - 1])) * (
        mssim_array[level - 1] ** weight[level - 1]
    )
    return overall_mssim


if __name__ == "__main__":
    grays = []
    for path in ("1.png", "2.png"):
        bgr = cv2.imread(path)
        # cv2.imread returns None (no exception) for missing/unreadable files
        if bgr is None:
            raise FileNotFoundError(f"could not read image {path!r}")
        grays.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
    im1, im2 = grays
    print(msssim(im1, im2))
