## SuperResDataset class definiton

In [35]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
import pytorch_lightning as pl
from tqdm import tqdm

class SuperResDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, sr_dir=None, transforms=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.sr_dir = sr_dir
        self.transforms = transforms if transforms else {'LR': None, 'HR': None, 'SR': None}

        self.filenames = os.listdir(lr_dir)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        lr_image = Image.open(os.path.join(self.lr_dir, self.filenames[idx])).convert("L")
        hr_image = Image.open(os.path.join(self.hr_dir, self.filenames[idx])).convert("L")

        sample = {'LR': lr_image, 'HR': hr_image}
        
        if self.sr_dir is not None:
            sr_image = Image.open(os.path.join(self.sr_dir, self.filenames[idx])).convert("L")
            sample['SR'] = sr_image

        sample = {k: self.transforms.get(k)(v) if self.transforms.get(k) else v for k, v in sample.items()}

        return sample

## Creating dataloaders

In [36]:
os.chdir('C:\\Users\\neuro-ws\\2Image-Super-Resolution-via-Iterative-Refinement\\model_workdir\\toy_data')
def compute_mean_and_std(loader):
    means = {'LR': 0., 'HR': 0., 'SR': 0.}
    stds = {'LR': 0., 'HR': 0., 'SR': 0.}
    counts = {'LR': 0, 'HR': 0, 'SR': 0}

    for batch in tqdm(loader):
        for image_type in batch.keys():
            counts[image_type] += 1
            means[image_type] += torch.mean(batch[image_type]).item()
            stds[image_type] += torch.std(batch[image_type]).item()
            
    means = {k: v / counts[k] for k, v in means.items()}
    stds = {k: v / counts[k] for k, v in stds.items()}

    return means, stds

def create_dataset_opts(lr_dir, hr_dir, sr_dir=None, batch_size=64, shuffle=False, num_workers=0, 
                        drop_last=True, mean=None, std=None, normalize_data=True, flipping=False): #change num_workers > 0 for non-jup envs

    base_transform = [transforms.ToTensor()]
    
    # Check if need to use flipping
    if flipping:
        # This flips the image horizontally and vertically randomly with a 50% probability
        base_transform.append(transforms.RandomVerticalFlip())
        base_transform.append(transforms.RandomHorizontalFlip())
    
    transform = transforms.Compose(base_transform)
    dataset = SuperResDataset(lr_dir, hr_dir, sr_dir, transforms={'LR': transform, 'HR': transform, 'SR': transform})
    
    if mean is None or std is None: 
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        mean, std = compute_mean_and_std(dataloader)

    if normalize_data:
        print('Normalizing')
        transform_lr = transforms.Compose(base_transform + [transforms.Normalize((mean['LR'],), (std['LR'], ))])
        transform_hr = transforms.Compose(base_transform + [transforms.Normalize((mean['HR'],), (std['HR'], ))])
        transform_sr = transforms.Compose(base_transform + [transforms.Normalize((mean['SR'],), (std['SR'], ))]) if sr_dir else None
        dataset = SuperResDataset(lr_dir, hr_dir, sr_dir, transforms={'LR': transform_lr, 'HR': transform_hr, 'SR': transform_sr})
        
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last)
    return dataloader, mean, std


#train_dataloader, train_mean, train_std = create_dataset_opts('train/lr', 'train/hr', 'train/sr', shuffle=True, normalize_data=False)

'''
# Compute mean and std for 'lr', 'hr' and 'sr' before normalization
print("Training data statistics before normalization")
train_mean_before_norm, train_std_before_norm = compute_mean_and_std(train_dataloader)
for key in train_mean_before_norm.keys():
    print(f"{key} mean: {train_mean_before_norm[key]}, std: {train_std_before_norm[key]}")

# With normalization
print("\nNormalizing training data...\n")
train_dataloader, train_mean, train_std = create_dataset_opts('train/lr', 'train/hr', 'train/sr', shuffle=True, normalize_data=True)

# Compute mean and std for 'lr', 'hr' and 'sr' after normalization
print("Training data statistics after normalization")
train_mean_after_norm, train_std_after_norm = compute_mean_and_std(train_dataloader)
for key in train_mean_after_norm.keys():
    print(f"{key} mean: {train_mean_after_norm[key]}, std: {train_std_after_norm[key]}")

'''

'\n# Compute mean and std for \'lr\', \'hr\' and \'sr\' before normalization\nprint("Training data statistics before normalization")\ntrain_mean_before_norm, train_std_before_norm = compute_mean_and_std(train_dataloader)\nfor key in train_mean_before_norm.keys():\n    print(f"{key} mean: {train_mean_before_norm[key]}, std: {train_std_before_norm[key]}")\n\n# With normalization\nprint("\nNormalizing training data...\n")\ntrain_dataloader, train_mean, train_std = create_dataset_opts(\'train/lr\', \'train/hr\', \'train/sr\', shuffle=True, normalize_data=True)\n\n# Compute mean and std for \'lr\', \'hr\' and \'sr\' after normalization\nprint("Training data statistics after normalization")\ntrain_mean_after_norm, train_std_after_norm = compute_mean_and_std(train_dataloader)\nfor key in train_mean_after_norm.keys():\n    print(f"{key} mean: {train_mean_after_norm[key]}, std: {train_std_after_norm[key]}")\n\n'

# OPTUNA JSONS

In [3]:
#Unet: 
{
  "out_channel": {"type": "integer", "low": 1, "high": 8},
  "inner_channel": {"type": "integer", "low": 32, "high": 64},
  "norm_groups": {"type": "integer", "low": 10, "high": 64},
  "channel_mults": {"type": "choice", "options": [[1, 2, 4, 8, 8], [1, 2, 4, 4], [1, 2, 2, 2]]},
  "attn_res": {"type": "integer", "low": 4, "high": 20},
  "res_blocks": {"type": "integer", "low": 1, "high": 5},
  "dropout": {"type": "float", "low": 0.0, "high": 0.5},
  "with_noise_level_emb": {"type": "categorical", "options": [true, false]},
  "activation_function": {"type": "categorical", "full_param_dictions": ["ReLU", "LeakyReLU", "Swish", "Mish", "ELU", "PReLU", "SELU", "GLU"]}
}

#SR3:

{
  "image_size": {"type": "int", "low": 64, "high": 1024},
  "channels": {"type": "int", "low": 1, "high": 3},
  "loss_type": {"type": "categorical", "choices": ["l1", "l2"]}, # добавить SSIM
  "conditional": {"type": "categorical", "choices": [True, False]},
  "if conditional==True": {
  "noise_schedule_params": {
      "schedule": {"type": "categorical", "choices": ["quad", "linear", "warmup10", "warmup50", "const", "jsd", "cosine"]},
      "n_timestep": {"type": "int", "low": 1, "high": 1000},
      "linear_start": {"type": "float", "low": 0.0, "high": 1.0},
      "linear_end": {"type": "float", "low": 0.0, "high": 1.0}
            }
        }
    }

NameError: name 'true' is not defined

# Extrajson

In [37]:
param_dict = {
    "unet": {
        "out_channel": {"type": "int", "low": 1, "high": 1},
        "norm_groups": {"type": "int", "low": 1, "high": 100},
        "inner_channel": {"type": "int", "low": 1, "high": 20},
        "channel_multiplier": [1,2,4],
        "attn_res": {"type": "list", "elements": {"type": "int", "low": 16, "high": 64}},
        "res_blocks": {"type": "int", "low": 1, "high": 50},
        "dropout": {"type": "float", "low": 0.0, "high": 0.5},
        "with_noise_level_emb": {"type": "categorical", "options": [True, False]},
        "activation_function": {"type": "categorical", "options": ["ReLU", "LeakyReLU", "Swish", "Mish", "ELU", "PReLU", "SELU", "GLU"]}
    },
    "diffusion": {
        "image_size": {"type": "int", "low": 256, "high": 256},
        "channels": {"type": "int", "low": 1, "high": 1},
        "loss_type": {"type": "categorical", "choices": ["l1", "l2"]}, # добавить SSIM
        "conditional": {"type": "categorical", "choices": [True, False]},
    },
    "beta_schedule_train": {
        "schedule": "linear",
          "n_timestep": 6000,
          "linear_start": 1e-8,
          "linear_end": 1e-2}
}


In [172]:
b = {
    "name": "soilCT",
    "phase": "train",
    "gpu_ids": [1],
    "distributed": True,
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": None
    },
    "datasets": {
        "train": {
            "name": "soilCT",
            "mode": "LRHR",
            "dataroot": "model_workdir\\data\\train",
            "datatype": "img",
            "l_resolution": 64,
            "r_resolution": 256,
            "batch_size": 8,
            "num_workers": 16,
            "use_shuffle": True,
            "data_len": -1
        },
        "val": {
            "name": "soilVal",
            "mode": "LRHR",
            "dataroot": "model_workdir\\data\\val",
            "datatype": "img",
            "l_resolution": 64,
            "r_resolution": 256,
            "data_len": -1
        }
    },
    "model": {
        "which_model_G": "sr3",
        "finetune_norm": False,
        "unet": {
            "in_channel": 2,
            "out_channel": 1,
            "inner_channel": 64,
            "channel_multiplier": [1, 2, 4],
            "attn_res": [32],
            "res_blocks": 2,
            "dropout": 0.2
        },
        "beta_schedule": {
            "train": {
                "schedule": "linear",
                "n_timestep": 7500,
                "linear_start": 1e-8,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "linear",
                "n_timestep": 7500,
                "linear_start": 1e-8,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            "image_size": 256,
            "channels": 1,
            "conditional": True
        }
    },
    "train": {
        "n_iter": 1000000,
        "val_freq": 1e4,
        "save_checkpoint_freq": 1e4,
        "print_freq": 200,
        "optimizer": {
            "type": "adam",
            "lr": 1e-4
        },
        "ema_scheduler": {
            "step_start_ema": 5000,
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "wandb": {
        "project": "sr_MedSoilCT"
    }
}


In [39]:
b = {**param_dict, **b}

# Print the merged dictionary
b

{'unet': {'out_channel': {'type': 'int', 'low': 1, 'high': 1},
  'norm_groups': {'type': 'int', 'low': 1, 'high': 100},
  'inner_channel': {'type': 'int', 'low': 1, 'high': 20},
  'channel_multiplier': [1, 2, 4],
  'attn_res': {'type': 'list',
   'elements': {'type': 'int', 'low': 16, 'high': 64}},
  'res_blocks': {'type': 'int', 'low': 1, 'high': 50},
  'dropout': {'type': 'float', 'low': 0.0, 'high': 0.5},
  'with_noise_level_emb': {'type': 'categorical', 'options': [True, False]},
  'activation_function': {'type': 'categorical',
   'options': ['ReLU',
    'LeakyReLU',
    'Swish',
    'Mish',
    'ELU',
    'PReLU',
    'SELU',
    'GLU']}},
 'diffusion': {'image_size': {'type': 'int', 'low': 256, 'high': 256},
  'channels': {'type': 'int', 'low': 1, 'high': 1},
  'loss_type': {'type': 'categorical', 'choices': ['l1', 'l2']},
  'conditional': {'type': 'categorical', 'choices': [True, False]}},
 'beta_schedule_train': {'schedule': 'linear',
  'n_timestep': 6000,
  'linear_start': 1e-

# Unet Module

In [167]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from inspect import isfunction


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# PositionalEncoding Source： https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
class PositionalEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        count = self.dim // 2
        step = torch.arange(count, dtype=noise_level.dtype,
                            device=noise_level.device) / count
        encoding = noise_level.unsqueeze(
            1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat(
            [torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding


class FeatureWiseAffine(nn.Module):
    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super(FeatureWiseAffine, self).__init__()
        self.use_affine_level = use_affine_level
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
        )

    def forward(self, x, noise_embed):
        batch = x.shape[0]
        if self.use_affine_level:
            gamma, beta = self.noise_func(noise_embed).view(
                batch, -1, 1, 1).chunk(2, dim=1)
            x = (1 + gamma) * x + beta
        else:
            x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
        return x


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Upsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


class Downsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim),
            Swish(),
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, padding=1)
        )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(
            noise_level_emb_dim, dim_out, use_affine_level)

        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        b, c, h, w = x.shape
        h = self.block1(x)
        h = self.noise_func(h, time_emb)
        h = self.block2(h)
        return h + self.res_conv(x)


class SelfAttention(nn.Module):
    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        batch, channel, height, width = input.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(input)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
        query, key, value = qkv.chunk(3, dim=2)  # bhdyx

        attn = torch.einsum(
            "bnchw, bncyx -> bnhwyx", query, key
        ).contiguous() / math.sqrt(channel)
        attn = attn.view(batch, n_head, height, width, -1)
        attn = torch.softmax(attn, -1)
        attn = attn.view(batch, n_head, height, width, height, width)

        out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
        out = self.out(out.view(batch, channel, height, width))

        return out + input


class ResnetBlocWithAttn(nn.Module):
    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb):
        x = self.res_block(x, time_emb)
        if(self.with_attn):
            x = self.attn(x)
        return x


class GrayscaleUNet(nn.Module):

    def __init__(
        self,
        in_channel=1,
        out_channel=1,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128
    ):
        super().__init__()

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        feat_channels = [pre_channel]
        now_res = image_size
        downs = [nn.Conv2d(1, inner_channel, kernel_size=3, padding=1)]

        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res//2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=False)
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks+1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res*2

        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)

    def forward(self, x, time):
        t = self.noise_level_mlp(time) if exists(
            self.noise_level_mlp) else None

        feats = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)

        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
            else:
                x = layer(x)

        return self.final_conv(x)


# SR3 Module (Rewrite as pl.LigthningModule if the envelope fails = uncomment the lower part of the code)

In [181]:
import math
import torch
from torch import device, nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
import numpy as np
from tqdm import tqdm
from pytorch_msssim import ssim



def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
    betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[:warmup_time] = np.linspace(
        linear_start, linear_end, warmup_time, dtype=np.float64)
    return betas


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == 'quad':
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
                            n_timestep, dtype=np.float64) ** 2
    elif schedule == 'linear':
        betas = np.linspace(linear_start, linear_end,
                            n_timestep, dtype=np.float64)
    elif schedule == 'warmup10':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.1)
    elif schedule == 'warmup50':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.5)
    elif schedule == 'const':
        betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(n_timestep,
                                 1, n_timestep, dtype=np.float64)
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) /
            n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * math.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = betas.clamp(max=0.999)
    else:
        raise NotImplementedError(schedule)
    return betas

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        denoise_fn,
        image_size = 256,
        channels=2,
        loss_type='l1',
        conditional=True,
        schedule_opt=None,
        learning_rate=1e-4
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.denoise_fn = denoise_fn
        self.loss_type = loss_type
        self.conditional = conditional
        self.learning_rate = learning_rate
        if schedule_opt is not None:
            pass
            #self.set_new_noise_schedule(schedule_opt)

    def set_loss(self, device):
        if self.loss_type == 'l1':
            self.loss_func = nn.L1Loss(reduction='sum').to(device)
        elif self.loss_type == 'l2':
            self.loss_func = nn.MSELoss(reduction='sum').to(device)
        elif self.loss_type == 'l2':
            self.loss_func = ssim.to(device)
        else:
            raise NotImplementedError()

    def set_new_noise_schedule(self, schedule_opt, device):
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)

        betas = make_beta_schedule(
            schedule=schedule_opt['schedule'],
            n_timestep=schedule_opt['n_timestep'],
            linear_start=schedule_opt['linear_start'],
            linear_end=schedule_opt['linear_end'])
        betas = betas.detach().cpu().numpy() if isinstance(
            betas, torch.Tensor) else betas
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        self.sqrt_alphas_cumprod_prev = np.sqrt(
            np.append(1., alphas_cumprod))

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * \
            (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance',
                             to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(
            np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    def predict_start_from_noise(self, x_t, t, noise):
        return self.sqrt_recip_alphas_cumprod[t] * x_t - \
            self.sqrt_recipm1_alphas_cumprod[t] * noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = self.posterior_mean_coef1[t] * \
            x_start + self.posterior_mean_coef2[t] * x_t
        posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
        return posterior_mean, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
        batch_size = x.shape[0]
        noise_level = torch.FloatTensor(
            [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
        if condition_x is not None:
            x_recon = self.predict_start_from_noise(
                x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level))
        else:
            x_recon = self.predict_start_from_noise(
                x, t=t, noise=self.denoise_fn(x, noise_level))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, condition_x=None):
        model_mean, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        return model_mean + noise * (0.5 * model_log_variance).exp()

    @torch.no_grad()
    def p_sample_loop(self, x_in, continous=False):
        device = self.betas.device
        sample_inter = (1 | (self.num_timesteps//10))
        if not self.conditional:
            shape = x_in
            img = torch.randn(shape, device=device)
            ret_img = img
            for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
                img = self.p_sample(img, i)
                if i % sample_inter == 0:
                    ret_img = torch.cat([ret_img, img], dim=0)
        else:
            x = x_in
            shape = x.shape
            img = torch.randn(shape, device=device)
            ret_img = x
            for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
                img = self.p_sample(img, i, condition_x=x)
                if i % sample_inter == 0:
                    ret_img = torch.cat([ret_img, img], dim=0)
        if continous:
            return ret_img
        else:
            return ret_img[-1]

    @torch.no_grad()
    def sample(self, batch_size=1, continous=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)

    @torch.no_grad()
    def super_resolution(self, x_in, continous=False):
        return self.p_sample_loop(x_in, continous)

    def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # random gama
        return (
            continuous_sqrt_alpha_cumprod * x_start +
            (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
        )

    def p_losses(self, batch, noise=None):
        #print(batch.shape)
        x_start = batch['HR']
        [b, c, h, w] = x_start.shape
        t = np.random.randint(1, self.num_timesteps + 1)
        continuous_sqrt_alpha_cumprod = torch.FloatTensor(
            np.random.uniform(
                self.sqrt_alphas_cumprod_prev[t-1],
                self.sqrt_alphas_cumprod_prev[t],
                size=b
            )
        ).to(x_start.device)
        continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(
            b, -1)

        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(
            x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)

        if not self.conditional:
            x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
        else:
            x_recon = self.denoise_fn(
                torch.cat([batch['SR'], x_noisy], dim=1), continuous_sqrt_alpha_cumprod)

        ssim_value = ssim(noise, x_recon)

        loss = self.loss_func(noise, x_recon)
        return loss, ssim_value

    def forward(self, x, *args, **kwargs):
        return self.p_losses(x, *args, **kwargs)
'''
    def training_step(self, batch, batch_idx):
        loss, ssim_value = self.p_losses(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_ssim', ssim_value, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, ssim_value = self.p_losses(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log('val_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": loss, "ssim": ssim_value}

    def test_step(self, batch, batch_idx):
        loss, ssim_value = self.p_losses(batch)
        self.log('test_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log('test_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": loss, "ssim": ssim_value}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
'''
    



'\n    def training_step(self, batch, batch_idx):\n        loss, ssim_value = self.p_losses(batch)\n        self.log(\'train_loss\', loss, on_step=False, on_epoch=True)\n        self.log(\'train_ssim\', ssim_value, on_step=False, on_epoch=True)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        loss, ssim_value = self.p_losses(batch)\n        self.log(\'val_loss\', loss, on_step=False, on_epoch=True, sync_dist=True)\n        self.log(\'val_ssim\', ssim_value, on_step=False, on_epoch=True, sync_dist=True)\n        return {"loss": loss, "ssim": ssim_value}\n\n    def test_step(self, batch, batch_idx):\n        loss, ssim_value = self.p_losses(batch)\n        self.log(\'test_loss\', loss, on_step=False, on_epoch=True, sync_dist=True)\n        self.log(\'test_ssim\', ssim_value, on_step=False, on_epoch=True, sync_dist=True)\n        return {"loss": loss, "ssim": ssim_value}\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameter

# Envelope block defined as lightningmodule (noise schedule + visuals + train/test/val etc.). Skip if not used

In [182]:
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from collections import OrderedDict
from pytorch_msssim import ssim

class BaseModel():
    def __init__(self, b):
        self.b = b
        self.device = torch.device(
            'cuda' if b['gpu_ids'] is not None else 'cpu')
        self.begin_step = 0
        self.begin_epoch = 0

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def set_device(self, x):
        if isinstance(x, dict):
            for key, item in x.items():
                if item is not None:
                    x[key] = item.to(self.device)
        elif isinstance(x, list):
            for item in x:
                if item is not None:
                    item = item.to(self.device)
        else:
            x = x.to(self.device)
        return x

    def get_network_description(self, network):
        if isinstance(network, nn.DataParallel):
            network = network.module
        s = str(network)
        n = sum(map(lambda x: x.numel(), network.parameters()))
        return s, n

class DDPM(pl.LightningModule, BaseModel):
    def __init__(self, b):
        super(DDPM, self).__init__()
        # define network and load pretrained models
        self.netG = self.set_device(SR3)
        self.schedule_phase = None
        self.b = b

        # set loss and load resume state
        self.set_loss()
        self.set_new_noise_schedule(b['model']['beta_schedule']['train'])

    def feed_data(self, data):
        self.data = self.set_device(data)

    def forward(self, x):
        # Your forward pass logic here
        return self.netG(x)

    def training_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)

        # set log
        self.log('l_pix', l_pix.item())
        return l_pix

    def validation_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)
        ssim_value = ssim(self.data['SR'], self.SR)

        self.log('val_loss', l_pix, on_step=False, on_epoch=True, sync_dist=True)
        self.log('val_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": l_pix, "ssim": ssim_value}

    def test_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)
        ssim_value = ssim(self.data['SR'], self.SR)

        self.log('test_loss', l_pix, on_step=False, on_epoch=True, sync_dist=True)
        self.log('test_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": l_pix, "ssim": ssim_value}

    def configure_optimizers(self):
        # Define your optimizer here
        optim_params = list(self.netG.parameters())
        optimizer = torch.optim.Adam(optim_params, lr=self.b['train']['optimizer']['lr'])
        return optimizer
    
    def optimize_parameters(self):
        self.optG = torch.optim.Adam(list(self.netG.parameters()), lr=self.b['train']['optimizer']['lr'])
        self.optG.zero_grad()
        l_pix = self.netG(self.data)
        # need to average in multi-gpu
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum()/int(b*c*h*w)
        l_pix.backward()
        self.optG.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self, continous=False):
        self.netG.eval()
        with torch.no_grad():
            if isinstance(self.netG, nn.DataParallel):
                self.SR = self.netG.module.super_resolution(
                    self.data['SR'], continous)
            else:
                self.SR = self.netG.super_resolution(
                    self.data['SR'], continous)
        self.netG.train()

    def sample(self, batch_size=1, continous=False):
        self.netG.eval()
        with torch.no_grad():
            if isinstance(self.netG, nn.DataParallel):
                self.SR = self.netG.module.sample(batch_size, continous)
            else:
                self.SR = self.netG.sample(batch_size, continous)
        self.netG.train()

    def set_loss(self):
        if isinstance(self.netG, nn.DataParallel):
            self.netG.module.set_loss(self.device)
        else:
            self.netG.set_loss(self.device)

    def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
        if self.schedule_phase is None or self.schedule_phase != schedule_phase:
            self.schedule_phase = schedule_phase
            if isinstance(self.netG, nn.DataParallel):
                self.netG.module.set_new_noise_schedule(
                    schedule_opt, self.device)
            else:
                self.netG.set_new_noise_schedule(schedule_opt, self.device)

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_LR=True, sample=False):
        out_dict = OrderedDict()
        if sample:
            out_dict['SAM'] = self.SR.detach().float().cpu()
        else:
            out_dict['SR'] = self.SR.detach().float().cpu()
            out_dict['INF'] = self.data['SR'].detach().float().cpu()
            out_dict['HR'] = self.data['HR'].detach().float().cpu()
            if need_LR and 'LR' in self.data:
                out_dict['LR'] = self.data['LR'].detach().float().cpu()
            else:
                out_dict['LR'] = out_dict['INF']
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = f'{self.netG.__class__.__name__} - {self.netG.module.__class__.__name__}'
        else:
            net_struc_str = f'{self.netG.__class__.__name__}'

        print(f'Network G structure: {net_struc_str}, with parameters: {n}')
        print(s)

    def save_network(self, epoch, iter_step):
        gen_path = os.path.join(
            self.b['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch))
        opt_path = os.path.join(
            self.b['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch))
        # gen
        network = self.netG
        if isinstance(self.netG, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, gen_path)
        # opt
        opt_state = {'epoch': epoch, 'iter': iter_step,
                     'scheduler': None, 'optimizer': None}
        opt_state['optimizer'] = self.optG.state_dict()
        torch.save(opt_state, opt_path)

        #logger.info(
            #'Saved model in [{:s}] ...'.format(gen_path))   WE DON'T NEED LOGGERS SINCE pl

    def load_network(self):
        load_path = self.b['path']['resume_state']
        if load_path is not None:
            #logger.info(
                #'Loading pretrained model for G [{:s}] ...'.format(load_path))  WE DON'T NEED LOGGERS SINCE pl
            gen_path = '{}_gen.pth'.format(load_path)
            opt_path = '{}_opt.pth'.format(load_path)
            # gen
            network = self.netG
            if isinstance(self.netG, nn.DataParallel):
                network = network.module
            network.load_state_dict(torch.load(
                gen_path), strict=(not self.b['model']['finetune_norm']))
            # network.load_state_dict(torch.load(
            #     gen_path), strict=False)
            if self.b['phase'] == 'train':
                # optimizer
                opt = torch.load(opt_path)
                self.optG.load_state_dict(opt['optimizer'])
                self.begin_step = opt['iter']
                self.begin_epoch = opt['epoch']

# Weight initizlization

In [183]:
from torch.nn import init

def weights_init_normal(m, std=0.02):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    # scale for 'kaiming', std for 'normal'.
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(
            weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError(
            'initialization method [{:s}] not implemented'.format(init_type))



In [184]:
# Creating an instance of GrayscaleUNet
Unet = GrayscaleUNet(
    in_channel=b['model']['unet']['in_channel'],
    out_channel=b['model']['unet']['out_channel'],  # assuming that the low limit equals the high limit
    inner_channel=b['model']['unet']['inner_channel'],
    channel_mults=b['model']['unet']['channel_multiplier'],  # you need to provide actual values for lists
    attn_res=b['model']['unet']['attn_res'],  # same here
    res_blocks=b['model']['unet']['res_blocks'],
    dropout=b['model']['unet']['dropout'],
)

# Creating an instance of GaussianDiffusion with the GrayscaleUNet model
SR3 = GaussianDiffusion(
    denoise_fn=Unet,
    channels=b['model']['diffusion']['channels'], 
    conditional=b['model']['diffusion']['conditional'],  # choosing the first choice as example
    schedule_opt=b['model']['beta_schedule'],  # choosing the first choice as example
)

In [185]:
init_weights(Unet, init_type='orthogonal') #pass?
init_weights(SR3, init_type='orthogonal')

# GPU distribution & PyTorch lightning utilization 

In [186]:
train_dataloader, train_mean, train_std = create_dataset_opts('train/lr', 'train/hr', 'train/sr', batch_size = 16, num_workers = 0, shuffle=True, normalize_data=False)
val_dataloader, _, _ = create_dataset_opts('val/lr', 'val/hr', 'val/sr', batch_size = 16, num_workers = 0, shuffle=True, normalize_data=False)
#uncomment the above to initialize d_loader

# Set up Lightning Trainer
trainer = pl.Trainer(
    max_epochs=99,
    gpus=[0],  # Set to 0 for CPU, or 1, 2, etc., for GPU
    log_every_n_steps=75,  # Log every n steps
    weights_summary=None,  # Choose 'full' or 'top' for more detailed model summary
)

DDPM = DDPM(b)

trainer.fit(DDPM, train_dataloader, val_dataloader)

100%|██████████| 94/94 [00:01<00:00, 68.96it/s]


{'LR': 0.4759022311327305, 'HR': 0.4755242105494154, 'SR': 0.475906620951409} {'LR': 0.18567011204171688, 'HR': 0.2027388689048747, 'SR': 0.18560686992838027}


100%|██████████| 19/19 [00:00<00:00, 66.37it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


{'LR': 0.48557542813451665, 'HR': 0.48518335505535726, 'SR': 0.48557984515240316} {'LR': 0.17618995120650843, 'HR': 0.1936349021761041, 'SR': 0.17612508958891818}
Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[16, 2, 256, 256] to have 1 channels, but got 2 channels instead

In [11]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
import pytorch_lightning as pl
from tqdm import tqdm
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_msssim import ssim
import math
from torch import device, nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
import numpy as np
from tqdm import tqdm
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from torch.nn import init
from collections import OrderedDict
import cv2
from torchvision.utils import make_grid

b = {
    "name": "soilCT",
    "phase": "train",
    "gpu_ids": [1],
    "distributed": True,
    "path": {
        "log": "C:\\Users\\neuro-ws\\2Image-Super-Resolution-via-Iterative-Refinement\\model_workdir\\experiments",
        "tb_logger": "C:\\Users\\neuro-ws\\2Image-Super-Resolution-via-Iterative-Refinement\\model_workdir\\experiments",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": None
    },
    "datasets": {
        "train": {
            "name": "soilCT",
            "mode": "LRHR",
            "dataroot": "model_workdir\\data\\train",
            "datatype": "img",
            "l_resolution": 64,
            "r_resolution": 256,
            "batch_size": 12,
            "num_workers": 16,
            "use_shuffle": True,
            "data_len": 300
        },
        "val": {
            "name": "soilVal",
            "mode": "LRHR",
            "dataroot": "model_workdir\\data\\val",
            "datatype": "img",
            "l_resolution": 64,
            "r_resolution": 256,
            "data_len": -1
        }
    },
    "model": {
        "which_model_G": "sr3",
        "finetune_norm": False,
        "unet": {
            "in_channel": 2,
            "out_channel": 1,
            "inner_channel": 64,
            "channel_multiplier": [1, 2, 4],
            "attn_res": [32],
            "res_blocks": 2,
            "dropout": 0.2
        },
        "beta_schedule": {
            "train": {
                "schedule": "linear",
                "n_timestep": 4000,
                "linear_start": 1e-9,
                "linear_end": 1e-3
            },
            "val": {
                "schedule": "linear",
                "n_timestep": 4000,
                "linear_start": 1e-9,
                "linear_end": 1e-3
            }
        },
        "diffusion": {
            "image_size": 256,
            "channels": 1,
            "conditional": True
        }
    },
    "train": {
        "n_iter": 1000000,
        "val_freq": 2e3,
        "save_checkpoint_freq": 1e4,
        "print_freq": 200,
        "optimizer": {
            "type": "adam",
            "lr": 1e-4
        },
        "ema_scheduler": {
            "step_start_ema": 5000,
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
}

tb_logger = TensorBoardLogger(save_dir=b['path']['tb_logger'], version=1)

class Metrics:
    def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
        '''
        Converts a torch Tensor into an image Numpy array
        Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
        Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
        '''
        tensor = tensor.repeat(1, 3, 1, 1)
        tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
        tensor = (tensor - min_max[0]) / \
            (min_max[1] - min_max[0])  # to range [0,1]
        n_dim = tensor.dim()
        if n_dim == 4:
            n_img = len(tensor)
            img_np = make_grid(tensor, nrow=int(
                math.sqrt(n_img)), normalize=False).numpy()
            img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
        elif n_dim == 3:
            img_np = tensor.numpy()
            img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
        elif n_dim == 2:
            img_np = tensor.numpy()
        else:
            raise TypeError(
                'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
        if out_type == np.uint8:
            img_np = (img_np * 255.0).round()
            # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
        return img_np.astype(out_type)


    def save_img(img, img_path, mode='RGB'):
        cv2.imwrite(img_path, img)
        # cv2.imwrite(img_path, img)


    def calculate_psnr(img1, img2):
        # img1 and img2 have range [0, 255]
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        mse = np.mean((img1 - img2)**2)
        if mse == 0:
            return float('inf')
        return 20 * math.log10(255.0 / math.sqrt(mse))


    def ssim(img1, img2):
        C1 = (0.01 * 255)**2
        C2 = (0.03 * 255)**2

        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())

        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1**2
        mu2_sq = mu2**2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()


    def calculate_ssim(img1, img2):
        '''calculate SSIM
        the same outputs as MATLAB's
        img1, img2: [0, 255]
        '''
        if not img1.shape == img2.shape:
            raise ValueError('Input images must have the same dimensions.')
        if img1.ndim == 2:
            return ssim(img1, img2)
        elif img1.ndim == 3:
            if img1.shape[2] == 3:
                ssims = []
                for i in range(3):
                    ssims.append(ssim(img1, img2))
                return np.array(ssims).mean()
            elif img1.shape[2] == 1:
                return ssim(np.squeeze(img1), np.squeeze(img2))
        else:
            raise ValueError('Wrong input image dimensions.')


class SuperResDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, sr_dir=None, transforms=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.sr_dir = sr_dir
        self.transforms = transforms if transforms else {'LR': None, 'HR': None, 'SR': None}

        self.filenames = os.listdir(lr_dir)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        lr_image = Image.open(os.path.join(self.lr_dir, self.filenames[idx])).convert("L")
        hr_image = Image.open(os.path.join(self.hr_dir, self.filenames[idx])).convert("L")

        sample = {'LR': lr_image, 'HR': hr_image}
        
        if self.sr_dir is not None:
            sr_image = Image.open(os.path.join(self.sr_dir, self.filenames[idx])).convert("L")
            sample['SR'] = sr_image

        sample = {k: self.transforms.get(k)(v) if self.transforms.get(k) else v for k, v in sample.items()}

        return sample

os.chdir('C:\\Users\\neuro-ws\\2Image-Super-Resolution-via-Iterative-Refinement\\model_workdir\\toy_data')
def compute_mean_and_std(loader):
    means = {'LR': 0., 'HR': 0., 'SR': 0.}
    stds = {'LR': 0., 'HR': 0., 'SR': 0.}
    counts = {'LR': 0, 'HR': 0, 'SR': 0}

    for batch in tqdm(loader):
        for image_type in batch.keys():
            counts[image_type] += 1
            means[image_type] += torch.mean(batch[image_type]).item()
            stds[image_type] += torch.std(batch[image_type]).item()
            
    means = {k: v / counts[k] for k, v in means.items()}
    stds = {k: v / counts[k] for k, v in stds.items()}

    return means, stds

def create_dataset_opts(lr_dir, hr_dir, sr_dir=None, batch_size=64, shuffle=False, num_workers=0, 
                        drop_last=True, mean=None, std=None, normalize_data=True, flipping=False): #change num_workers > 0 for non-jup envs

    base_transform = [transforms.ToTensor()]
    
    # Check if need to use flipping
    if flipping:
        # This flips the image horizontally and vertically randomly with a 50% probability
        base_transform.append(transforms.RandomVerticalFlip())
        base_transform.append(transforms.RandomHorizontalFlip())
    
    transform = transforms.Compose(base_transform)
    dataset = SuperResDataset(lr_dir, hr_dir, sr_dir, transforms={'LR': transform, 'HR': transform, 'SR': transform})
    
    if mean is None or std is None: 
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        mean, std = compute_mean_and_std(dataloader)

    if normalize_data:
        transform_lr = transforms.Compose(base_transform + [transforms.Normalize((mean['LR'],), (std['LR'], ))])
        transform_hr = transforms.Compose(base_transform + [transforms.Normalize((mean['HR'],), (std['HR'], ))])
        transform_sr = transforms.Compose(base_transform + [transforms.Normalize((mean['SR'],), (std['SR'], ))]) if sr_dir else None
        dataset = SuperResDataset(lr_dir, hr_dir, sr_dir, transforms={'LR': transform_lr, 'HR': transform_hr, 'SR': transform_sr})
        
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last)
    return dataloader, mean, std


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# PositionalEncoding Source： https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
class PositionalEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        count = self.dim // 2
        step = torch.arange(count, dtype=noise_level.dtype,
                            device=noise_level.device) / count
        encoding = noise_level.unsqueeze(
            1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat(
            [torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding


class FeatureWiseAffine(nn.Module):
    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super(FeatureWiseAffine, self).__init__()
        self.use_affine_level = use_affine_level
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
        )

    def forward(self, x, noise_embed):
        batch = x.shape[0]
        if self.use_affine_level:
            gamma, beta = self.noise_func(noise_embed).view(
                batch, -1, 1, 1).chunk(2, dim=1)
            x = (1 + gamma) * x + beta
        else:
            x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
        return x


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Upsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


class Downsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim),
            Swish(),
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, padding=1)
        )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(
            noise_level_emb_dim, dim_out, use_affine_level)

        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        b, c, h, w = x.shape
        h = self.block1(x)
        h = self.noise_func(h, time_emb)
        h = self.block2(h)
        return h + self.res_conv(x)


class SelfAttention(nn.Module):
    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        batch, channel, height, width = input.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(input)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
        query, key, value = qkv.chunk(3, dim=2)  # bhdyx

        attn = torch.einsum(
            "bnchw, bncyx -> bnhwyx", query, key
        ).contiguous() / math.sqrt(channel)
        attn = attn.view(batch, n_head, height, width, -1)
        attn = torch.softmax(attn, -1)
        attn = attn.view(batch, n_head, height, width, height, width)

        out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
        out = self.out(out.view(batch, channel, height, width))

        return out + input


class ResnetBlocWithAttn(nn.Module):
    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb):
        x = self.res_block(x, time_emb)
        if(self.with_attn):
            x = self.attn(x)
        return x


class GrayscaleUNet(nn.Module):

    def __init__(
        self,
        in_channel=1,
        out_channel=1,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128
    ):
        super().__init__()

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        feat_channels = [pre_channel]
        now_res = image_size
        downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]

        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res//2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                               dropout=dropout, with_attn=False)
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks+1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res*2

        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)

    def forward(self, x, time):
        t = self.noise_level_mlp(time) if exists(
            self.noise_level_mlp) else None
        feats = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)

        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
            else:
                x = layer(x)

        return self.final_conv(x)

def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
    betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[:warmup_time] = np.linspace(
        linear_start, linear_end, warmup_time, dtype=np.float64)
    return betas


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == 'quad':
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
                            n_timestep, dtype=np.float64) ** 2
    elif schedule == 'linear':
        betas = np.linspace(linear_start, linear_end,
                            n_timestep, dtype=np.float64)
    elif schedule == 'warmup10':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.1)
    elif schedule == 'warmup50':
        betas = _warmup_beta(linear_start, linear_end,
                             n_timestep, 0.5)
    elif schedule == 'const':
        betas = linear_end * np.ones(n_timestep, dtype=np.float64)
    elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(n_timestep,
                                 1, n_timestep, dtype=np.float64)
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) /
            n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * math.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = betas.clamp(max=0.999)
    else:
        raise NotImplementedError(schedule)
    return betas

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        denoise_fn,
        image_size = 256,
        channels=1,
        loss_type='l1',
        conditional=True,
        schedule_opt=None,
        learning_rate=1e-4
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.denoise_fn = denoise_fn
        self.loss_type = loss_type
        self.conditional = conditional
        self.learning_rate = learning_rate
        if schedule_opt is not None:
            pass
            #self.set_new_noise_schedule(schedule_opt)

    def set_loss(self, device):
        if self.loss_type == 'l1':
            self.loss_func = nn.L1Loss(reduction='sum').to(device)
        elif self.loss_type == 'l2':
            self.loss_func = nn.MSELoss(reduction='sum').to(device)
        elif self.loss_type == 'l2':
            self.loss_func = ssim.to(device)
        else:
            raise NotImplementedError()

    def set_new_noise_schedule(self, schedule_opt, device):
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)

        betas = make_beta_schedule(
            schedule=schedule_opt['schedule'],
            n_timestep=schedule_opt['n_timestep'],
            linear_start=schedule_opt['linear_start'],
            linear_end=schedule_opt['linear_end'])
        betas = betas.detach().cpu().numpy() if isinstance(
            betas, torch.Tensor) else betas
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        self.sqrt_alphas_cumprod_prev = np.sqrt(
            np.append(1., alphas_cumprod))

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * \
            (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance',
                             to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(
            np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    def predict_start_from_noise(self, x_t, t, noise):
        return self.sqrt_recip_alphas_cumprod[t] * x_t - \
            self.sqrt_recipm1_alphas_cumprod[t] * noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = self.posterior_mean_coef1[t] * \
            x_start + self.posterior_mean_coef2[t] * x_t
        posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
        return posterior_mean, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
        batch_size = x.shape[0]
        noise_level = torch.FloatTensor(
            [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
        if condition_x is not None:
            x_recon = self.predict_start_from_noise(
                x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level))
        else:
            x_recon = self.predict_start_from_noise(
                x, t=t, noise=self.denoise_fn(x, noise_level))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, condition_x=None):
        model_mean, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        return model_mean + noise * (0.5 * model_log_variance).exp()

    @torch.no_grad()
    def p_sample_loop(self, x_in, continous=False):
        device = self.betas.device
        sample_inter = (1 | (self.num_timesteps//10))
        if not self.conditional:
            shape = x_in
            img = torch.randn(shape, device=device)
            ret_img = img
            for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
                img = self.p_sample(img, i)
                if i % sample_inter == 0:
                    ret_img = torch.cat([ret_img, img], dim=0)
        else:
            x = x_in
            shape = x.shape
            img = torch.randn(shape, device=device)
            ret_img = x
            for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
                img = self.p_sample(img, i, condition_x=x)
                if i % sample_inter == 0:
                    ret_img = torch.cat([ret_img, img], dim=0)
        if continous:
            return ret_img
        else:
            return ret_img[-1]

    @torch.no_grad()
    def sample(self, batch_size=1, continous=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)

    @torch.no_grad()
    def super_resolution(self, x_in, continous=False):
        return self.p_sample_loop(x_in, continous)

    def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # random gama
        return (
            continuous_sqrt_alpha_cumprod * x_start +
            (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
        )

    def p_losses(self, batch, noise=None):
        x_start = batch['HR']
        [b, c, h, w] = x_start.shape
        t = np.random.randint(1, self.num_timesteps + 1)
        continuous_sqrt_alpha_cumprod = torch.FloatTensor(
            np.random.uniform(
                self.sqrt_alphas_cumprod_prev[t-1],
                self.sqrt_alphas_cumprod_prev[t],
                size=b
            )
        ).to(x_start.device)
        continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(
            b, -1)

        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(
            x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)

        if not self.conditional:
            x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
        else:
            x_recon = self.denoise_fn(
                torch.cat([batch['SR'], x_noisy], dim=1), continuous_sqrt_alpha_cumprod)

        ssim_value = ssim(x_start, x_recon)

        loss = self.loss_func(noise, x_recon)
        return loss, ssim_value

    def forward(self, x, *args, **kwargs):
        return self.p_losses(x, *args, **kwargs)

class BaseModel():
    def __init__(self, b):
        self.b = b
        self.device = torch.device(
            'cuda' if b['gpu_ids'] is not None else 'cpu')
        self.begin_step = 0
        self.begin_epoch = 0

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def set_device(self, x):
        if isinstance(x, dict):
            for key, item in x.items():
                if item is not None:
                    x[key] = item.to(self.device)
        elif isinstance(x, list):
            for item in x:
                if item is not None:
                    item = item.to(self.device)
        else:
            x = x.to(self.device)
        return x

    def get_network_description(self, network):
        if isinstance(network, nn.DataParallel):
            network = network.module
        s = str(network)
        n = sum(map(lambda x: x.numel(), network.parameters()))
        return s, n

class DDPM(pl.LightningModule, BaseModel):
    def __init__(self, b):
        super(DDPM, self).__init__()
        self.netG = self.set_device(SR3)
        self.schedule_phase = None
        self.b = b
        self.set_loss()
        self.set_new_noise_schedule(b['model']['beta_schedule']['train'])
        self.visuals_dir = os.path.join(b['path']['log'], 'visuals')
        self.is_first_epoch = True  

    def feed_data(self, data):
        self.data = self.set_device(data)

    def forward(self, x):
        return self.netG(x)

    def training_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix, ssim = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)

        #self.SR = self.netG.super_resolution(self.data['HR'], continous=False)

        # set log
        self.log('l_pix', l_pix.item())
        return l_pix

    def validation_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix, ssim = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)
        ssim_value = ssim

        if batch_idx>2:
            print('visuals logging')
            # Generate super-resolved images directly within the loop
            sr_images = self.netG.super_resolution(self.data['HR'], continous=False)

            # Logging visuals
            result_path = os.path.join(self.visuals_dir, f'validation_{batch_idx}')
            os.makedirs(result_path, exist_ok=True)

            for idx, (lr_img, hr_img, sr_img, bc_img) in enumerate(zip(self.data['LR'], self.data['HR'], sr_images, self.data['SR'])):
                # Save LR, HR, SR, and bicubic-interpolated images
                lr_path = os.path.join(result_path, f'validation_{idx}_lr.png')
                hr_path = os.path.join(result_path, f'validation_{idx}_hr.png')
                sr_path = os.path.join(result_path, f'validation_{idx}_sr.png')
                bicubic_path = os.path.join(result_path, f'validation_{idx}_bicubic.png')
                lr_img, hr_img, sr_img, bc_img = Metrics.tensor2img(lr_img), Metrics.tensor2img(hr_img), Metrics.tensor2img(sr_img), Metrics.tensor2img(bc_img)
                Metrics.save_img(lr_img, lr_path)
                Metrics.save_img(hr_img, hr_path)
                Metrics.save_img(sr_img, sr_path)
                Metrics.save_img(bc_img, bicubic_path)

            # Log metrics
            self.log('val_loss', l_pix, on_step=False, on_epoch=True, sync_dist=True)
            self.log('val_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)

            return {"loss": l_pix, "ssim": ssim_value}

    def test_step(self, batch, batch_idx):
        self.feed_data(batch)
        l_pix, ssim = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum() / int(b * c * h * w)
        ssim_value = ssim(self.data['SR'], self.SR)

        self.log('test_loss', l_pix, on_step=False, on_epoch=True, sync_dist=True)
        self.log('test_ssim', ssim_value, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": l_pix, "ssim": ssim_value}

    def configure_optimizers(self):
        optim_params = list(self.netG.parameters())
        optimizer = torch.optim.Adam(optim_params, lr=self.b['train']['optimizer']['lr'])
        return optimizer
    
    def optimize_parameters(self):
        self.optG = torch.optim.Adam(list(self.netG.parameters()), lr=self.b['train']['optimizer']['lr'])
        self.optG.zero_grad()
        l_pix, ssim = self.netG(self.data)
        b, c, h, w = self.data['HR'].shape
        l_pix = l_pix.sum()/int(b*c*h*w)
        l_pix.backward()
        self.optG.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self, continous=False):
        self.netG.eval()
        with torch.no_grad():
            if isinstance(self.netG, nn.DataParallel):
                self.SR = self.netG.module.super_resolution(
                    self.data['SR'], continous)
            else:
                self.SR = self.netG.super_resolution(
                    self.data['SR'], continous)
        self.netG.train()

    def sample(self, batch_size=1, continous=False):
        self.netG.eval()
        with torch.no_grad():
            if isinstance(self.netG, nn.DataParallel):
                self.SR = self.netG.module.sample(batch_size, continous)
            else:
                self.SR = self.netG.sample(batch_size, continous)
        self.netG.train()

    def set_loss(self):
        if isinstance(self.netG, nn.DataParallel):
            self.netG.module.set_loss(self.device)
        else:
            self.netG.set_loss(self.device)

    def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
        if self.schedule_phase is None or self.schedule_phase != schedule_phase:
            self.schedule_phase = schedule_phase
            if isinstance(self.netG, nn.DataParallel):
                self.netG.module.set_new_noise_schedule(
                    schedule_opt, self.device)
            else:
                self.netG.set_new_noise_schedule(schedule_opt, self.device)

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_LR=True, sample=False):
        out_dict = OrderedDict()
        if sample:
            out_dict['SAM'] = self.SR.detach().float().cpu()
        else:
            out_dict['SR'] = self.SR.detach().float().cpu()
            out_dict['INF'] = self.data['SR'].detach().float().cpu()
            out_dict['HR'] = self.data['HR'].detach().float().cpu()
            if need_LR and 'LR' in self.data:
                out_dict['LR'] = self.data['LR'].detach().float().cpu()
            else:
                out_dict['LR'] = out_dict['INF']
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = f'{self.netG.__class__.__name__} - {self.netG.module.__class__.__name__}'
        else:
            net_struc_str = f'{self.netG.__class__.__name__}'

        print(f'Network G structure: {net_struc_str}, with parameters: {n}')

    def save_network(self, epoch, iter_step):
        gen_path = os.path.join(
            self.b['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch))
        opt_path = os.path.join(
            self.b['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch))
        # gen
        network = self.netG
        if isinstance(self.netG, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, gen_path)
        # opt
        opt_state = {'epoch': epoch, 'iter': iter_step,
                     'scheduler': None, 'optimizer': None}
        opt_state['optimizer'] = self.optG.state_dict()
        torch.save(opt_state, opt_path)

    def load_network(self):
        load_path = self.b['path']['resume_state']
        if load_path is not None:
            gen_path = '{}_gen.pth'.format(load_path)
            opt_path = '{}_opt.pth'.format(load_path)
            network = self.netG
            if isinstance(self.netG, nn.DataParallel):
                network = network.module
            network.load_state_dict(torch.load(
                gen_path), strict=(not self.b['model']['finetune_norm']))
            if self.b['phase'] == 'train':
                opt = torch.load(opt_path)
                self.optG.load_state_dict(opt['optimizer'])
                self.begin_step = opt['iter']
                self.begin_epoch = opt['epoch']

def weights_init_normal(m, std=0.02):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
        init.constant_(m.bias.data, 0.0)

def weights_init_kaiming(m, scale=1):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)

def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(
            weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError(
            'initialization method [{:s}] not implemented'.format(init_type))


Unet = GrayscaleUNet(
    in_channel=b['model']['unet']['in_channel'],
    out_channel=b['model']['unet']['out_channel'],  # assuming that the low limit equals the high limit
    inner_channel=b['model']['unet']['inner_channel'],
    channel_mults=b['model']['unet']['channel_multiplier'],  # you need to provide actual values for lists
    attn_res=b['model']['unet']['attn_res'],  # same here
    res_blocks=b['model']['unet']['res_blocks'],
    dropout=b['model']['unet']['dropout'],
)

SR3 = GaussianDiffusion(
    denoise_fn=Unet,
    channels=b['model']['diffusion']['channels'], 
    conditional=b['model']['diffusion']['conditional'],  # choosing the first choice as example
    schedule_opt=b['model']['beta_schedule'],  # choosing the first choice as example
)

init_weights(Unet, init_type='orthogonal') #pass?
init_weights(SR3, init_type='orthogonal')

train_dataloader, train_mean, train_std = create_dataset_opts('train/lr', 'train/hr', 'train/sr', batch_size = 8, num_workers = 0, shuffle=True, normalize_data=False)
val_dataloader, _, _ = create_dataset_opts('val/lr', 'val/hr', 'val/sr', batch_size = 8, num_workers = 0, shuffle=True, normalize_data=False)
trainer = pl.Trainer(
    max_epochs=99,
    gpus=[0],  # Set to 0 for CPU, or 1, 2, etc., for GPU
    check_val_every_n_epoch=2,  
    log_every_n_steps=180,
    weights_summary=None,  # Choose 'full' or 'top' for more detailed model summary
    logger=tb_logger,
    
)
DDPM = DDPM(b)

trainer.fit(DDPM, train_dataloader, val_dataloader)

Epoch 1:  85%|████████▌ | 191/224 [13:48<02:22,  4.32s/it, loss=0.122, v_num=1]
Epoch 1:  85%|████████▌ | 191/224 [06:28<01:06,  2.03s/it, loss=0.106, v_num=1]



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 188/188 [00:01<00:00, 132.96it/s]

[A
[A
100%|██████████| 38/38 [00:00<00:00, 137.59it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                      

  rank_zero_warn(


Epoch 1:  85%|████████▌ | 191/224 [01:46<00:18,  1.81it/s, loss=0.118, v_num=1] visuals logging



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Epoch 1:  85%|████████▌ | 191/224 [01:57<00:20,  1.64it/s, loss=0.118, v_num=1]


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Epoch 1:  86%|████████▌ | 192/224 [13:23<02:13,  4.16s/it, loss=0.118, v_num=1]




visuals logging



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
sampling loop time step:   5%|▍         | 182/4000 [00:31<11:05,  5.74it/s]
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [2]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()
