In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Losses


### Content Loss- Squared distance between feature representations, basically just nn.MSELoss but fit for my workflow

In [35]:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        self.register_buffer('target', target)

    def forward(self, input):
        return nn.MSELoss()(input, self.target)


In [44]:
# Testing content Loss
target_map = torch.randn(1, 3, 4, 4)
pruning_map = torch.randn(1, 3, 4, 4)

# Instantiate ContentLoss with the target feature map
content_loss = ContentLoss(target_map)

# Compute the cnotent loss for the generated image
loss = content_loss(pruning_map)

print("Conent Loss:", loss.item())


Conent Loss: 1.3877025842666626


### Style Loss- Squared distance betwen Gram Matrices

In [22]:
# 4d Tensor -> Gram Matrix
class GramMatrix(nn.Module):
    def forward(self, v):
        # Flatten
        v_f = v.flatten(-2)
        # Transpose (switch last two layers)
        v_f_t = v_f.transpose(-2, -1)
        # Matrix multiplication
        v_mul = v_f @ v_f_t
        # Normalize
        gram = v_mul / (v_mul.shape[0] * v_mul.shape[1])
        return gram

class StyleLoss(nn.Module):
    # Register target gram matrix for reuse
    def __init__(self, target_gram, eps=1e-8):
        super().__init__()
        self.register_buffer('target_gram', target_gram)

    # Forward pass- Gram Matrix distance
    def forward(self, input):
        return nn.MSELoss()(GramMatrix()(input), self.target_gram)

    

In [45]:
# Testing Style Loss
style_feature_map = torch.randn(1, 3, 4, 4)
generated_feature_map = torch.randn(1, 3, 4, 4)

# Compute the target Gram matrix from the style image feature map
target_gram = GramMatrix()(style_feature_map)

# Instantiate StyleLoss with the target Gram matrix
style_loss = StyleLoss(target_gram)

# Compute the style loss for the generated image
loss = style_loss(generated_feature_map)

print("Style Loss:", loss.item())



Style Loss: 8.624126434326172


Content Loss