In [None]:
class Options:
    """Hard-coded configuration for the rotation experiment.

    ``torch`` must already be imported when an instance is created, because
    the constructor queries CUDA availability.
    """

    def __init__(self):
        # Directories listed for content / style images by the driver script.
        self.contentPath = "data/content/"
        self.stylePath = "data/style/"
        # loadSize is not read by the visible driver code; fineSize is the
        # square side length images are resized to.
        self.loadSize = 256
        self.fineSize = 256
        # Checkpoints for the encoder, decoder and transformation module.
        self.vgg_dir = 'models/vgg_r41.pth'
        self.decoder_dir = 'models/dec_r41.pth'
        # Previously assigned twice ("Matrices/" first); only this value was
        # ever effective, so the dead first assignment has been removed.
        self.matrixPath = 'models/r41.pth'
        # Selects the encoder/decoder pair: 'r31' or 'r41'.
        self.layer = 'r41'
        # Output directory for the generated figures and images.
        self.outf = "Artistic/rotation/"
        self.cuda = torch.cuda.is_available()
        self.batchSize = 1

if __name__ == "__main__":
    import os
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
    import torchvision.transforms as transforms
    from tqdm import tqdm
    from libs.Loader import Dataset
    from libs.Matrix import MulLayer
    from libs.models import encoder3, encoder4, decoder3, decoder4
    from libs.utils import print_options
    import torch.backends.cudnn as cudnn

    def rotate_matrix(matrix, theta_degrees):
        """
        Rotate a transformation matrix by the same angle in every coordinate plane.

        For each pair of indices ``(i, j)`` with ``i < j`` along the matrix's
        last dimension, a plane rotation by ``theta_degrees`` is built (identity
        except for the cos/sin entries at rows/columns ``i`` and ``j``) and
        right-multiplied onto the running result, in lexicographic pair order.

        Args:
            matrix: Tensor whose last dimension is rotated, e.g. ``(b, m, n)``.
            theta_degrees: Rotation angle in degrees, used for every plane.

        Returns:
            A new tensor with the same shape, dtype and device as ``matrix``.
        """
        # detach() so a tensor that tracks gradients can still be exported to NumPy.
        matrix_np = matrix.detach().cpu().numpy()

        # Right-multiplication acts on the LAST axis, so its size determines the
        # rotation matrices. (shape[1] only coincides with it for square inputs.)
        dim = matrix_np.shape[-1]

        # The angle is identical for every plane: compute cos/sin once.
        theta = np.radians(theta_degrees)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)

        rotated = matrix_np.copy()

        # Rotate in every plane spanned by a pair of axes.
        for i in range(dim - 1):
            for j in range(i + 1, dim):
                rot_matrix = np.eye(dim)
                rot_matrix[i, i] = cos_t
                rot_matrix[i, j] = -sin_t
                rot_matrix[j, i] = sin_t
                rot_matrix[j, j] = cos_t

                # np.eye is float64, so the accumulation runs in float64 and is
                # cast back to the input dtype only once, on return.
                rotated = np.matmul(rotated, rot_matrix)

        return torch.tensor(rotated, device=matrix.device, dtype=matrix.dtype)

    # Rotation experiment driver: for every (style, content) image pair, compute
    # the learned transformation matrix, rotate it by each angle in `thetas`,
    # decode the result, and save a grid figure plus the individual images.
    opt = Options()
    print_options(opt)

    os.makedirs(opt.outf, exist_ok=True)
    cudnn.benchmark = True

    ################# MODEL #################
    if opt.layer == 'r31':
        vgg = encoder3()
        dec = decoder3()
    elif opt.layer == 'r41':
        vgg = encoder4()
        dec = decoder4()
    else:
        # Fail early with a clear message instead of a NameError on `vgg` below.
        raise ValueError(f"Unsupported layer '{opt.layer}'; expected 'r31' or 'r41'")
    matrix = MulLayer(opt.layer)
    # map_location='cpu' lets the checkpoints load on machines without CUDA;
    # the modules are moved to the GPU afterwards when one is available.
    vgg.load_state_dict(torch.load(opt.vgg_dir, map_location='cpu'))
    dec.load_state_dict(torch.load(opt.decoder_dir, map_location='cpu'))
    matrix.load_state_dict(torch.load(opt.matrixPath, map_location='cpu'))

    ################# GPU #################
    if opt.cuda:
        vgg.cuda()
        dec.cuda()
        matrix.cuda()

    # Case-insensitive extension match; sorted so runs are reproducible
    # (os.listdir returns entries in arbitrary order).
    image_exts = ('.jpg', '.jpeg', '.png')
    content_files = sorted(f for f in os.listdir(opt.contentPath) if f.lower().endswith(image_exts))
    style_files = sorted(f for f in os.listdir(opt.stylePath) if f.lower().endswith(image_exts))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((opt.fineSize, opt.fineSize))
    ])

    # Rotation angles to sweep: 0 to 9 degrees in steps of 1
    thetas = list(range(0, 10, 1))

    # On the 'r41' path the encoder output is indexed by layer name; otherwise
    # it is used directly.
    indexed_features = opt.layer == 'r41'

    for style in tqdm(style_files, desc="Processing Styles"):
        style_image = Image.open(os.path.join(opt.stylePath, style)).convert('RGB')
        styleV = transform(style_image).unsqueeze(0)
        if opt.cuda:
            styleV = styleV.cuda()

        # Style features do not depend on the content image: encode once per style.
        with torch.no_grad():
            sF = vgg(styleV)
        style_feat = sF[opt.layer] if indexed_features else sF

        for content in tqdm(content_files, desc="Processing Contents", leave=False):
            content_image = Image.open(os.path.join(opt.contentPath, content)).convert('RGB')
            contentV = transform(content_image).unsqueeze(0)
            if opt.cuda:
                contentV = contentV.cuda()

            ################# FORWARD PASS WITH ROTATION #################
            images = []  # One decoded image per rotation angle, in `thetas` order

            with torch.no_grad():
                cF = vgg(contentV)
                # The conditional must select the tensor itself; placing it
                # inside the subscript would index cF with cF on the r31 path.
                content_feat = cF[opt.layer] if indexed_features else cF

                # Only the transformation matrix is needed; the already
                # transformed feature is discarded.
                _, transmatrix = matrix(content_feat, style_feat, trans=True)

                compress_content = matrix.compress(content_feat)
                b, c, h, w = compress_content.size()
                compress_content = compress_content.view(b, c, -1)

                # Per-channel spatial mean of the content feature; identical
                # for every angle, so compute it once.
                content_mean = torch.mean(content_feat, dim=(2, 3), keepdim=True)

                for theta in thetas:
                    rotated_matrix = rotate_matrix(transmatrix, theta)

                    # Apply the rotated matrix to the compressed content feature
                    transfeature = torch.bmm(rotated_matrix, compress_content).view(b, matrix.matrixSize, h, w)
                    out = matrix.unzip(transfeature) + content_mean

                    transfer_rotated = dec(out).clamp(0, 1)

                    # (1, 3, H, W) -> (H, W, 3) for matplotlib
                    img_numpy = transfer_rotated.squeeze().cpu().numpy().transpose(1, 2, 0)
                    images.append(img_numpy)

                    if opt.cuda:
                        torch.cuda.empty_cache()

            # Plot all angles in a two-row grid
            num_rows = 2
            num_cols = len(thetas) // 2 + len(thetas) % 2
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 8))

            for idx, ax in enumerate(axes.flatten()):
                ax.axis('off')  # also hides any unused trailing cell
                if idx < len(images):
                    ax.imshow(images[idx])
                    ax.set_title(f'θ={thetas[idx]}°')

            plt.tight_layout()
            plt.savefig(os.path.join(opt.outf, f'rotation_experiment_{style}_{content}.png'))
            plt.close(fig)

            # Also save individual images
            for theta, img in zip(thetas, images):
                plt.imsave(os.path.join(opt.outf, f'rotation_{style}_{content}_angle_{theta}.png'), img)

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
from libs.Matrix import MulLayer
from libs.models import encoder3, encoder4, decoder3, decoder4
from libs.utils import print_options
import torch.backends.cudnn as cudnn

class Options:
    """Hard-coded configuration for the translation experiment."""

    def __init__(self):
        # Directories listed for content / style images by the driver script.
        self.contentPath = "data/content/"
        self.stylePath = "data/style/"
        # loadSize is not read by the visible driver code; fineSize is the
        # square side length images are resized to.
        self.loadSize = 256
        self.fineSize = 256
        # Checkpoints for the encoder, decoder and transformation module.
        self.vgg_dir = 'models/vgg_r41.pth'
        self.decoder_dir = 'models/dec_r41.pth'
        # Previously assigned twice ("Matrices/" first); only this value was
        # ever effective, so the dead first assignment has been removed.
        self.matrixPath = 'models/r41.pth'
        # Selects the encoder/decoder pair: 'r31' or 'r41'.
        self.layer = 'r41'
        # Output directory for the generated images.
        self.outf = "Artistic/translation/"
        self.cuda = torch.cuda.is_available()
        self.batchSize = 1

def translate_matrix(matrix, x_shift, y_shift):
    """
    Circularly shift the entries of a transformation matrix.

    Entries are rolled by ``x_shift`` positions along dimension 1 and by
    ``y_shift`` positions along dimension 2; values pushed past the end wrap
    around to the start. Note these are the row/column axes of the matrix
    passed in, not image height/width.

    Args:
        matrix: Tensor with at least three dimensions, e.g. ``(b, m, n)``.
        x_shift: Number of positions to roll along dimension 1.
        y_shift: Number of positions to roll along dimension 2.

    Returns:
        A new tensor with the same shape, dtype and device as ``matrix``;
        the input is left unmodified.
    """
    # torch.roll keeps the data on its device (no NumPy round trip); detach()
    # preserves the old contract of returning a tensor that carries no
    # autograd history, and also makes grad-tracking inputs acceptable.
    return torch.roll(matrix.detach(), shifts=(x_shift, y_shift), dims=(1, 2))

# Translation experiment driver: for every (style, content) image pair, compute
# the learned transformation matrix, circularly shift it by each offset in
# `translations`, decode the result, and save one image per offset.
if __name__ == "__main__":
    opt = Options()
    print_options(opt)

    os.makedirs(opt.outf, exist_ok=True)
    cudnn.benchmark = True

    ################# MODEL #################
    if opt.layer == 'r31':
        vgg = encoder3()
        dec = decoder3()
    elif opt.layer == 'r41':
        vgg = encoder4()
        dec = decoder4()
    else:
        # Fail early with a clear message instead of a NameError on `vgg` below.
        raise ValueError(f"Unsupported layer '{opt.layer}'; expected 'r31' or 'r41'")
    matrix = MulLayer(opt.layer)
    # map_location='cpu' lets the checkpoints load on machines without CUDA;
    # the modules are moved to the GPU afterwards when one is available.
    vgg.load_state_dict(torch.load(opt.vgg_dir, map_location='cpu'))
    dec.load_state_dict(torch.load(opt.decoder_dir, map_location='cpu'))
    matrix.load_state_dict(torch.load(opt.matrixPath, map_location='cpu'))

    ################# GPU #################
    if opt.cuda:
        vgg.cuda()
        dec.cuda()
        matrix.cuda()

    # Case-insensitive extension match; sorted so runs are reproducible
    # (os.listdir returns entries in arbitrary order).
    image_exts = ('.jpg', '.jpeg', '.png')
    content_files = sorted(f for f in os.listdir(opt.contentPath) if f.lower().endswith(image_exts))
    style_files = sorted(f for f in os.listdir(opt.stylePath) if f.lower().endswith(image_exts))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((opt.fineSize, opt.fineSize))
    ])

    # All (x, y) shifts in 0..5 except the identity shift (0, 0)
    translations = [(x, y) for x in range(6) for y in range(6) if (x != 0 or y != 0)]

    # On the 'r41' path the encoder output is indexed by layer name; otherwise
    # it is used directly.
    indexed_features = opt.layer == 'r41'

    # Processing loop
    for style in tqdm(style_files, desc="Processing Styles"):
        style_image = Image.open(os.path.join(opt.stylePath, style)).convert('RGB')
        styleV = transform(style_image).unsqueeze(0)
        if opt.cuda:
            styleV = styleV.cuda()

        # Style features do not depend on the content image: encode once per style.
        with torch.no_grad():
            sF = vgg(styleV)
        style_feat = sF[opt.layer] if indexed_features else sF

        for content in tqdm(content_files, desc="Processing Contents", leave=False):
            content_image = Image.open(os.path.join(opt.contentPath, content)).convert('RGB')
            contentV = transform(content_image).unsqueeze(0)
            if opt.cuda:
                contentV = contentV.cuda()

            ################# FORWARD PASS WITH TRANSLATION #################
            with torch.no_grad():
                cF = vgg(contentV)
                # The conditional must select the tensor itself; placing it
                # inside the subscript would index cF with cF on the r31 path.
                content_feat = cF[opt.layer] if indexed_features else cF

                # Only the transformation matrix is needed; the already
                # transformed feature is discarded.
                _, transmatrix = matrix(content_feat, style_feat, trans=True)

                compress_content = matrix.compress(content_feat)
                b, c, h, w = compress_content.size()
                compress_content = compress_content.view(b, c, -1)

                # Per-channel spatial mean of the content feature; identical
                # for every translation, so compute it once.
                content_mean = torch.mean(content_feat, dim=(2, 3), keepdim=True)

                # Apply each translation and save images individually
                for (x_shift, y_shift) in translations:
                    translated_matrix = translate_matrix(transmatrix, x_shift, y_shift)

                    transfeature = torch.bmm(translated_matrix, compress_content).view(b, matrix.matrixSize, h, w)
                    out = matrix.unzip(transfeature) + content_mean

                    transfer_translated = dec(out).clamp(0, 1)
                    # (1, 3, H, W) -> (H, W, 3) for matplotlib
                    img_numpy = transfer_translated.squeeze().cpu().numpy().transpose(1, 2, 0)

                    # Save each translated image with a title indicating translation amount
                    fig, ax = plt.subplots(figsize=(5, 5))
                    ax.imshow(img_numpy)
                    ax.axis('off')
                    ax.set_title(f'Translation=({x_shift}, {y_shift})')
                    plt.tight_layout()
                    plt.savefig(os.path.join(
                        opt.outf,
                        f'{content}_style_{style}_translation_{x_shift}_{y_shift}.png',
                    ))
                    plt.close(fig)


----------------- Options ---------------
                batchSize: 1                             
              contentPath: data/content/                 
                     cuda: True                          
              decoder_dir: models/dec_r41.pth            
                 fineSize: 256                           
                    layer: r41                           
                 loadSize: 256                           
               matrixPath: models/r41.pth                
                     outf: Artistic/translation/         
                stylePath: data/style/                   
                  vgg_dir: models/vgg_r41.pth            
----------------- End -------------------


  vgg.load_state_dict(torch.load(opt.vgg_dir))
  dec.load_state_dict(torch.load(opt.decoder_dir))
  matrix.load_state_dict(torch.load(opt.matrixPath))
Processing Styles:  71%|███████▏  | 15/21 [30:26<12:26, 124.34s/it]