In [1]:
import torch
import torch.nn.functional as F

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x2cefeeed4d0>

# Input

In [3]:
# INPUTS: 1 channel, 3x3 images (batch size = 1)
content_img = torch.tensor(
    [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]], requires_grad=False
)

style_img = torch.tensor(
    [[[[9.0, 8.0, 7.0], [6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]]], requires_grad=False
)

output_img = torch.zeros_like(content_img, requires_grad=True)

In [4]:
print("=== INPUT IMAGES ===")
print("Content Image:\n", content_img[0, 0])
print("Style Image:\n", style_img[0, 0])
print("Output Image (init):\n", output_img.detach()[0, 0])

=== INPUT IMAGES ===
Content Image:
 tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
Style Image:
 tensor([[9., 8., 7.],
        [6., 5., 4.],
        [3., 2., 1.]])
Output Image (init):
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])


# Model

In [5]:
# 1 input channel, 1 output channel, kernel size = 2x2
conv_weight = torch.tensor(
    [[[[1.0, 0.0], [0.0, -1.0]]]], requires_grad=False
)  # shape [1, 1, 2, 2]

In [6]:
def simple_cnn(x, weight):
    return F.conv2d(x, weight, stride=1, padding=0)

In [7]:
# Extract feature maps
F_content = simple_cnn(content_img, conv_weight)
F_style = simple_cnn(style_img, conv_weight)
F_output = simple_cnn(output_img, conv_weight)

In [8]:
print("\n=== FEATURE MAPS ===")
print("F_content:\n", F_content[0, 0])
print("F_style:\n", F_style[0, 0])
print("F_output:\n", F_output[0, 0])


=== FEATURE MAPS ===
F_content:
 tensor([[-4., -4.],
        [-4., -4.]])
F_style:
 tensor([[4., 4.],
        [4., 4.]])
F_output:
 tensor([[0., 0.],
        [0., 0.]], grad_fn=<SelectBackward0>)


# Loss and Gram matrix

In [9]:
def gram_matrix(x):
    b, c, h, w = x.shape
    features = x.view(c, h * w)
    return torch.mm(features, features.t())  # shape: [c, c]

In [10]:
def content_loss(F_target, F_content):
    return F.mse_loss(F_target, F_content)


def style_loss(F_target, F_style):
    G_target = gram_matrix(F_target)
    G_style = gram_matrix(F_style)
    print(f"{G_target=}, {G_style=}")
    return F.mse_loss(G_target, G_style)

In [11]:
# Tính loss
c_loss = content_loss(F_output, F_content)
s_loss = style_loss(F_output, F_style)
total_loss = c_loss + s_loss

G_target=tensor([[0.]], grad_fn=<MmBackward0>), G_style=tensor([[64.]])


In [12]:
print("\n=== LOSSES ===")
print(f"Content Loss: {c_loss.item():.4f}")
print(f"Style Loss: {s_loss.item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")


=== LOSSES ===
Content Loss: 16.0000
Style Loss: 4096.0000
Total Loss: 4112.0000


In [13]:
# Backward
total_loss.backward()

# Learning rate
lr = 0.01
with torch.no_grad():
    output_img -= lr * output_img.grad
    output_img.grad.zero_()

In [14]:
print("\n=== GRADIENTS ===")
print("Grad of output_img:\n", output_img.grad[0, 0])

print("\n=== OUTPUT IMAGE AFTER 1 UPDATE ===")
print(output_img.detach()[0, 0])


=== GRADIENTS ===
Grad of output_img:
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

=== OUTPUT IMAGE AFTER 1 UPDATE ===
tensor([[-0.0200, -0.0200,  0.0000],
        [-0.0200,  0.0000,  0.0200],
        [ 0.0000,  0.0200,  0.0200]])
