# In [None]:
import os
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from tqdm import tqdm

from libs.Criterion import LossCriterion
from libs.Loader import Dataset
from libs.Matrix import MulLayer
from libs.models import encoder3, encoder4
from libs.models import decoder3, decoder4

class LossCriterion(nn.Module):
    """Weighted combination of per-layer content and style losses.

    Content terms are computed with ``nn.MSELoss`` and style terms with
    ``styleLoss``. Reference features are detached before the comparison so
    no gradient flows back into them.

    NOTE: this definition shadows the ``LossCriterion`` name imported from
    ``libs.Criterion`` at the top of the file.
    """

    def __init__(self, style_layers, content_layers, style_weight, content_weight):
        """Store the layer names and weights and build the per-layer criteria.

        Args:
            style_layers: Layer names used to index the feature mappings for
                the style terms.
            content_layers: Layer names used for the content terms.
            style_weight: Multiplier applied to the summed style loss.
            content_weight: Multiplier applied to the summed content loss.
        """
        super(LossCriterion, self).__init__()
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight
        # A single criterion object of each kind is created; every list entry
        # refers to that same instance.
        shared_style_criterion = styleLoss()
        self.styleLosses = [shared_style_criterion for _ in range(len(style_layers))]
        shared_content_criterion = nn.MSELoss()
        self.contentLosses = [shared_content_criterion for _ in range(len(content_layers))]

    @staticmethod
    def _sum_over_layers(layers, criteria, produced, reference):
        """Add up ``criteria[i](produced[layer], reference[layer].detach())``."""
        running = 0
        for idx, name in enumerate(layers):
            target = reference[name].detach()
            output = produced[name]
            criterion = criteria[idx]
            running = running + criterion(output, target)
        return running

    def forward(self, tF, sF, cF):
        """Compute the combined loss.

        Args:
            tF: Mapping from layer name to features of the generated image.
            sF: Mapping from layer name to style reference features.
            cF: Mapping from layer name to content reference features.

        Returns:
            A tuple ``(loss, totalStyleLoss, totalContentLoss)`` where ``loss``
            is the sum of the two weighted terms.
        """
        # Content term is evaluated first, then the style term.
        content_sum = self._sum_over_layers(
            self.content_layers, self.contentLosses, tF, cF)
        totalContentLoss = content_sum * self.content_weight

        style_sum = self._sum_over_layers(
            self.style_layers, self.styleLosses, tF, sF)
        totalStyleLoss = style_sum * self.style_weight

        return totalStyleLoss + totalContentLoss, totalStyleLoss, totalContentLoss


class styleLoss(nn.Module):
    """Style loss built from per-channel means and Gram matrices.

    The loss is the summed squared error between the channel means of the two
    feature maps plus the summed squared error between their ``GramMatrix``
    outputs, divided by the batch size of ``target``.
    """

    def forward(self, input, target):
        """Return the style loss between ``input`` and ``target``.

        Args:
            input: 4-D feature tensor of the generated image.
            target: 4-D feature tensor of the style reference.

        Returns:
            A scalar tensor holding the loss.

        Raises:
            ValueError: If either tensor does not have exactly four
                dimensions (the size unpacking below requires it).
        """
        # Unpacking four sizes keeps the requirement that inputs are 4-D.
        ib, ic, _, _ = input.size()
        iMean = torch.mean(input.view(ib, ic, -1), dim=2)
        iCov = GramMatrix()(input)

        tb, tc, _, _ = target.size()
        tMean = torch.mean(target.view(tb, tc, -1), dim=2)
        tCov = GramMatrix()(target)

        # `size_average=False` is deprecated and emits a warning;
        # `reduction='sum'` is the documented equivalent.
        summed_squared_error = nn.MSELoss(reduction='sum')
        loss = summed_squared_error(iMean, tMean) + summed_squared_error(iCov, tCov)
        return loss / tb

class GramMatrix(nn.Module):
    """Normalised Gram matrix of a 4-D feature map."""

    def forward(self, input):
        """Return ``F @ F^T / (c * h * w)`` for each batch element.

        ``F`` is ``input`` flattened over its last two dimensions, so the
        result has shape ``(b, c, c)``.
        """
        batch, channels, height, width = input.size()
        flat = input.view(batch, channels, height * width)
        # (b, c, hw) x (b, hw, c) -> (b, c, c)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        gram.div_(channels * height * width)
        return gram

class LossSensitivity:
    """Measure how the style/content loss reacts to noise in a transformation matrix.

    For a content/style image pair the class encodes both images, lets
    ``matrix`` transform and compress the features, multiplies the compressed
    features by a (noisy) matrix, decodes the result and scores the decoded
    image with ``LossCriterion``.

    This is a plain helper class, not an ``nn.Module``; ``forward`` is an
    ordinary method. All computations run without gradient tracking because
    only scalar loss values are reported.
    """

    def __init__(self, vgg: nn.Module, dec: nn.Module, matrix: MulLayer,
                 style_layers: List[str], content_layers: List[str],
                 style_weight: float, content_weight: float, device: torch.device):
        """Move the three networks to ``device`` and build the loss criterion.

        Args:
            vgg: Encoder; its output is indexed by layer name.
            dec: Decoder applied to the transformed features.
            matrix: Transformation module providing ``compress``, ``unzip``
                and ``matrixSize``.
            style_layers: Layer names for the style loss. The first entry is
                also the layer whose features are transformed.
            content_layers: Layer names for the content loss.
            style_weight: Weight of the style term.
            content_weight: Weight of the content term.
            device: Device the networks are moved to.
        """
        self.vgg = vgg.to(device)
        self.dec = dec.to(device)
        self.matrix = matrix.to(device)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.criterion = LossCriterion(style_layers, content_layers, style_weight, content_weight)
        self.device = device

    def add_noise(self, matrix: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return ``matrix`` plus zero-mean Gaussian noise scaled by ``sigma``."""
        return matrix + torch.randn_like(matrix) * sigma

    @torch.no_grad()
    def forward(self, contentV: torch.Tensor, styleV: torch.Tensor) -> Tuple[dict, dict]:
        """Encode both images; returns ``(style_features, content_features)``."""
        return self.vgg(styleV), self.vgg(contentV)

    @torch.no_grad()
    def _prepare_features(self, contentV: torch.Tensor, styleV: torch.Tensor):
        """Compute everything that does not depend on the noisy matrix.

        Returns:
            ``(sF, cF, feature_size, compressed_features)`` where
            ``feature_size`` is the size of the transformed feature map and
            ``compressed_features`` is the output of ``matrix.compress``.
        """
        sF, cF = self.forward(contentV, styleV)
        layer = self.style_layers[0]
        # NOTE(review): the noisy matrix is later applied to features that
        # `self.matrix` has already transformed - confirm this is intended.
        transformed_features, _ = self.matrix(cF[layer], sF[layer])
        compressed_features = self.matrix.compress(transformed_features)
        return sF, cF, transformed_features.size(), compressed_features

    @torch.no_grad()
    def _loss_from_features(self, sF, cF, feature_size, compressed_features,
                            noisy_matrix: torch.Tensor) -> float:
        """Apply ``noisy_matrix`` to prepared features and return the total loss.

        Returns ``float('inf')`` (after printing a message) when the second
        dimension of ``noisy_matrix`` does not match ``matrix.matrixSize``.
        """
        b, _, h, w = feature_size
        flat_features = compressed_features.view(b, self.matrix.matrixSize, -1)

        if noisy_matrix.size(1) != flat_features.size(1):
            print(f"Dimension mismatch: {noisy_matrix.size()} vs {compressed_features.size()}")
            return float('inf')

        noisy_transfeature = torch.bmm(noisy_matrix, flat_features)
        noisy_transfeature = noisy_transfeature.view(b, self.matrix.matrixSize, h, w)
        noisy_transfeature = self.matrix.unzip(noisy_transfeature)

        # Decode, clamp to the [0, 1] range and re-encode for the loss.
        noisy_transfer = self.dec(noisy_transfeature).clamp(0, 1)
        tF = self.vgg(noisy_transfer)

        total_loss, _, _ = self.criterion(tF, sF, cF)
        return total_loss.item()

    def compute_loss(self, contentV: torch.Tensor, styleV: torch.Tensor, noisy_matrix: torch.Tensor) -> float:
        """Return the total loss obtained with ``noisy_matrix``.

        Args:
            contentV: Content image batch.
            styleV: Style image batch.
            noisy_matrix: Matrix multiplied (``torch.bmm``) with the
                compressed features.

        Returns:
            The total loss as a Python float, or ``float('inf')`` when the
            matrix dimension does not match ``matrix.matrixSize``.
        """
        sF, cF, feature_size, compressed_features = self._prepare_features(contentV, styleV)
        return self._loss_from_features(sF, cF, feature_size, compressed_features, noisy_matrix)

    def run_experiment(self, contentV: torch.Tensor, styleV: torch.Tensor,
                       sigmas: np.ndarray, matrix: torch.Tensor) -> List[float]:
        """Compute the loss for each noise level in ``sigmas``.

        Args:
            contentV: Content image batch.
            styleV: Style image batch.
            sigmas: Noise standard deviations to test.
            matrix: Clean transformation matrix that noise is added to.

        Returns:
            One loss per sigma, in order. Sigmas whose loss is
            ``float('inf')`` (dimension mismatch) are skipped, so the list
            can be shorter than ``sigmas``; callers must check its length.
        """
        # The encoder features and compressed features do not depend on
        # sigma, so they are computed once instead of once per noise level
        # (assumes the encoder and matrix forward passes are deterministic).
        sF, cF, feature_size, compressed_features = self._prepare_features(contentV, styleV)

        loss_values = []
        for sigma in sigmas:
            noisy_matrix = self.add_noise(matrix, sigma)

            loss = self._loss_from_features(sF, cF, feature_size, compressed_features, noisy_matrix)
            if loss == float('inf'):
                print(f"Skipping sigma {sigma} due to dimension mismatch.")
                continue

            loss_values.append(loss)  # Store only the loss related to random noise

        return loss_values


def process_style_dir(style_dir: str, opt, loss_sensitivity: LossSensitivity,
                      sigmas: np.ndarray, device: torch.device) -> np.ndarray:
    """Average the loss-vs-sigma curve over all matrices saved for one style.

    For every ``.pth`` file in ``opt.matrixPath/style_dir`` the experiment is
    run against every content image; the curves are averaged over content
    images first and then over matrices.

    Args:
        style_dir: Sub-directory of ``opt.matrixPath`` holding the matrices.
        opt: Options object providing ``matrixPath``, ``contentPath``,
            ``stylePath``, ``loadSize`` and ``fineSize``.
        loss_sensitivity: Object whose ``run_experiment`` produces the losses.
        sigmas: Noise levels to test.
        device: Device tensors are loaded onto.

    Returns:
        Array with one averaged loss per sigma. All entries are NaN when no
        matrix yielded a usable result.
    """
    style_path = os.path.join(opt.matrixPath, style_dir)
    matrix_files = [f for f in os.listdir(style_path) if f.endswith('.pth')]

    num_sigmas = len(sigmas)
    total_loss_values = np.zeros(num_sigmas)
    num_matrices = 0

    content_dataset = Dataset(opt.contentPath, opt.loadSize, opt.fineSize)
    style_dataset = Dataset(opt.stylePath, opt.loadSize, opt.fineSize)

    # Loop over all matrices saved for this style
    for matrix_file in tqdm(matrix_files, desc=f"Processing {style_dir}"):
        matrix_path = os.path.join(style_path, matrix_file)
        saved_matrix = torch.load(matrix_path, map_location=device)

        num_content_images = 0
        loss_values_accumulated = np.zeros(num_sigmas)

        # Loop over all content images
        for contentV, _ in content_dataset:
            contentV = contentV.unsqueeze(0).to(device)
            # NOTE(review): the style image is always index 0 of the style
            # dataset, independent of `style_dir` - confirm this is intended.
            # It is re-read on every iteration on purpose: the dataset may
            # apply random transforms, so hoisting could change results.
            styleV = style_dataset[0][0].unsqueeze(0).to(device)

            # Run experiment for this content image
            loss_values = np.asarray(
                loss_sensitivity.run_experiment(contentV, styleV, sigmas, saved_matrix),
                dtype=float)

            # run_experiment drops sigmas it could not evaluate; a shorter
            # result cannot be aligned with `sigmas`, so skip it instead of
            # failing on a shape mismatch.
            if loss_values.shape != (num_sigmas,):
                print(f"Skipping a content image for {matrix_file}: expected "
                      f"{num_sigmas} loss values, got {loss_values.size}.")
                continue

            # Accumulate losses for each sigma
            loss_values_accumulated += loss_values
            num_content_images += 1

        if num_content_images == 0:
            # Nothing usable for this matrix; leave it out of the average
            # rather than dividing by zero.
            print(f"No usable results for {matrix_file}; skipping it.")
            continue

        # Average loss values for each content image and add to total losses
        total_loss_values += loss_values_accumulated / num_content_images
        num_matrices += 1

    if num_matrices == 0:
        # Same value the 0/0 division used to produce, without the warning.
        return np.full(num_sigmas, np.nan)

    # Average the accumulated results over the number of style matrices
    return total_loss_values / num_matrices



def plot_style_results(style_dir: str, sigmas: np.ndarray, avg_loss_values: np.ndarray):
    """Plot average loss against noise level for one style, save and show it.

    The figure is written to ``noise_sensitivity_<style_dir>.png`` in the
    current working directory before being displayed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(sigmas, avg_loss_values, '-o')
    ax.set_xlabel('Sigma (Noise Level)')
    ax.set_ylabel('Average Loss')
    ax.set_title(f'Loss Sensitivity for Style: {style_dir}')
    ax.grid(True)

    fig.tight_layout()
    fig.savefig(f'noise_sensitivity_{style_dir}.png')
    plt.show()


def plot_intermediate_images(intermediate_images: List[Tuple[float, torch.Tensor]], style_dir: str):
    """Plot stylized images side by side, one per sigma level.

    Args:
        intermediate_images: ``(sigma, image_tensor)`` pairs. After
            ``squeeze()`` each tensor must be channel-first with three
            dimensions, because it is permuted to channel-last for
            ``imshow``.
        style_dir: Style name used in the title and in the output file name
            ``stylized_images_<style_dir>.png``.
    """
    plt.figure(figsize=(15, 5))
    num_images = len(intermediate_images)

    for idx, (sigma, image_tensor) in enumerate(intermediate_images):
        plt.subplot(1, num_images, idx + 1)
        # `.numpy()` raises for tensors on a CUDA device or tensors that
        # track gradients, so detach and copy to host memory first.
        image = image_tensor.detach().cpu().squeeze().permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.title(f'Sigma = {sigma:.2f}')
        plt.axis('off')

    plt.suptitle(f'Stylized Images for Style: {style_dir}')
    plt.tight_layout()
    plt.savefig(f'stylized_images_{style_dir}.png')
    plt.show()


def plot_style_loss_trend_for_noise_level(style_dir: str, sigmas: np.ndarray, avg_loss_values: np.ndarray):
    """Show the average loss across content images as a function of sigma.

    Unlike ``plot_style_results`` this only displays the figure; nothing is
    written to disk.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(sigmas, avg_loss_values, '-o')
    ax.set_xlabel('Sigma (Noise Level)')
    ax.set_ylabel('Average Loss Across Content Images')
    ax.set_title(f'Average Loss for Style: {style_dir}')
    ax.grid(True)
    fig.tight_layout()
    plt.show()


def load_models(device: torch.device,
                vgg_path: str = 'models/vgg_r41.pth',
                dec_path: str = 'models/dec_r41.pth',
                matrix_path: Optional[str] = None) -> Tuple[nn.Module, nn.Module, MulLayer]:
    """Build the encoder, decoder and transformation module and load weights.

    Args:
        device: Passed as ``map_location`` when reading checkpoints. The
            returned modules are NOT moved to this device here.
        vgg_path: Checkpoint file for the encoder.
        dec_path: Checkpoint file for the decoder.
        matrix_path: Optional checkpoint file for the ``MulLayer``. When
            ``None`` (the default, matching the previous behavior) the layer
            keeps its freshly initialised weights.

    Returns:
        ``(vgg, dec, matrix)``.
    """
    vgg = encoder4()
    dec = decoder4()
    matrix = MulLayer('r41')
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    dec.load_state_dict(torch.load(dec_path, map_location=device))
    # NOTE(review): without `matrix_path` the MulLayer's compress/unzip
    # weights are never loaded - confirm that running the experiment with an
    # untrained layer is intended.
    if matrix_path is not None:
        matrix.load_state_dict(torch.load(matrix_path, map_location=device))
    return vgg, dec, matrix

class Options:
    """Paths and image sizes used by the experiment.

    Every setting has the previous hard-coded value as its default, so
    ``Options()`` behaves as before while individual values can now be
    overridden at construction time.
    """

    def __init__(self,
                 contentPath: str = "data/content/",
                 stylePath: str = "data/style/",
                 loadSize: int = 256,
                 fineSize: int = 256,
                 matrixPath: str = "Matrices/"):
        # Directory with content images.
        self.contentPath = contentPath
        # Directory with style images.
        self.stylePath = stylePath
        # Size arguments forwarded to the dataset loader.
        self.loadSize = loadSize
        self.fineSize = fineSize
        # Root directory with one sub-directory of saved matrices per style.
        self.matrixPath = matrixPath

def main():
    """Run the noise-sensitivity experiment for every style directory.

    Loads the models, sweeps 100 sigma values between 0 and 200 for each
    sub-directory of ``Options().matrixPath`` and plots the averaged loss
    curve per style.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load models
    vgg, dec, matrix = load_models(device)

    # Define options and parameters
    opt = Options()
    sigmas = np.linspace(0, 200, 100)  # Noise levels to test

    # Initialize the loss sensitivity object with valid layer identifiers
    loss_sensitivity = LossSensitivity(vgg, dec, matrix, style_layers=['r41'], content_layers=['r41'],
                                       style_weight=1.0, content_weight=1.0, device=device)

    # Only sub-directories are style directories: a stray file (for example
    # a hidden OS metadata file) would make the per-style directory listing
    # fail. Sorting gives a reproducible processing order.
    style_dirs = sorted(
        entry for entry in os.listdir(opt.matrixPath)
        if os.path.isdir(os.path.join(opt.matrixPath, entry))
    )

    for style_dir in style_dirs:
        # Process each style directory to get average loss values for each sigma
        avg_loss_values = process_style_dir(style_dir, opt, loss_sensitivity, sigmas, device)

        # Plot the trend of average loss across content images for this style
        plot_style_loss_trend_for_noise_level(style_dir, sigmas, avg_loss_values)

# Run the experiment only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
    main()