In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Losses


### Content Loss- Squared distance between feature representations, basically just nn.MSELoss but fit for my workflow

In [6]:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        self.register_buffer('target', target)

    def forward(self, input):
        return nn.MSELoss()(input, self.target)


In [11]:
# Testing content Loss
target_map = torch.randn(1, 3, 224, 224)
pruning_map = torch.randn(1, 3, 224, 224)

# Instantiate ContentLoss with the target feature map
content_loss = ContentLoss(target_map)

# Compute the cnotent loss for the generated image
loss = content_loss(pruning_map)

print("Conent Loss:", loss.item())
type(loss)

Conent Loss: 1.990310549736023


torch.Tensor

### Style Loss- Squared distance betwen Gram Matrices

In [42]:
# 4d Tensor -> Gram Matrix
class GramMatrix(nn.Module):
    def forward(self, v):
        # Flatten
        v_f = v.flatten(-2)
        # Transpose (switch last two layers)
        v_f_t = v_f.transpose(-2, -1)
        # Matrix multiplication
        v_mul = v_f @ v_f_t
        # Normalize
        gram = v_mul / (v_mul.shape[0] * v_mul.shape[1])
        return gram

# class StyleLoss(nn.Module):
#     # Register target gram matrix for reuse
#     def __init__(self, target_gram, eps=1e-8):
#         super().__init__()
#         self.eps =eps
#         self.register_buffer('target_gram', target_gram)
                


#     # Forward pass- Gram Matrix distance
#     def forward(self, input):
#         input_gram = GramMatrix()(input)

#         # Calculate the number of elements in the input Gram matrix
#         # Adding eps for numerical stability
#         num_elements = input_gram.nelement() + self.eps
#         print(num_elements)
        
        
#         return nn.MSELoss()(input_gram, self.target_gram) / (input.shape[-1] * input.shape[-2])
# Testing this new style loss
class StyleLoss(nn.Module):
    def __init__(self, target_gram):
        super(StyleLoss, self).__init__()
        self.target = target_gram

    def forward(self, G, input):

        self.loss = nn.functional.mse_loss(G, self.target, reduction='sum')
        N = input.size(0)
        M = input.size(1) * input.size(2)  # Height times width of the feature map.
        self.loss /= (4 * (N ** 2) * (M ** 2))
        return self.loss

    

In [38]:
# Testing Style Loss
style_feature_map = torch.randn(1, 3, 224, 224)
generated_feature_map = torch.randn(1, 3, 224, 224)

# Compute the target Gram matrix from the style image feature map
target_gram = GramMatrix()(style_feature_map)

# Instantiate StyleLoss with the target Gram matrix
style_loss = StyleLoss(target_gram)

# Compute the style loss for the generated image
loss = style_loss(generated_feature_map)

print("Style Loss:", loss.item())



9.00000001
Style Loss: 0.32466599345207214


Content Loss

In [14]:
import math

math.prod(style_feature_map.shape)

150528

In [15]:
style_feature_map.nelement()

150528

# TV Loss

In [39]:
class TVLoss(nn.Module):
    def forward(self, input):
        x_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
        y_diff = input[..., :-1, :-1] - input[..., 1:, :-1]
        diff = x_diff**2 + y_diff**2
        return torch.sum(diff)
