In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/19901_2007.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/47071_2014.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/823_2000.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/16836_2007.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/37175_2012.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/58761_2017.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/58395_2017.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/30919_2010.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/11352_2005.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/23582_2008.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/36898_2012.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/60582_2018.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/31736_2010.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/36882_2012.jpg
/kaggle/input/avataranime/AvatarAnime/AnimeFace/val/56217_2017.j

In [None]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights
from PIL import Image
import os
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm
import timm # Thư viện chứa các mô hình pre-trained, bao gồm ViT

In [25]:
# --- 1. Thiết lập cơ bản và Hằng số ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [26]:
# --- 2. Định nghĩa Hằng số và Transforms ---
IMG_SIZE = 224 # Kích thước ảnh thống nhất
# Mean/Std của ImageNet (quan trọng khi dùng model pre-trained)
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

# Transforms cho huấn luyện (có augmentation)
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),          # Resize về kích thước cố định trước
    transforms.RandomCrop(IMG_SIZE),      # Crop ngẫu nhiên về đúng kích thước
    transforms.RandomHorizontalFlip(),    # Lật ngang ngẫu nhiên
    transforms.ToTensor(),                # Chuyển thành tensor [0, 1]
    transforms.Normalize(mean=IMG_MEAN, std=IMG_STD) # Chuẩn hóa
])

# Transforms cho validation/test (không augmentation ngẫu nhiên)
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),          # Resize về kích thước cố định
    transforms.CenterCrop(IMG_SIZE),      # Crop giữa về đúng kích thước
    transforms.ToTensor(),                # Chuyển thành tensor [0, 1]
    transforms.Normalize(mean=IMG_MEAN, std=IMG_STD) # Chuẩn hóa
])

# Hàm để denormalize ảnh khi hiển thị (chuyển từ tensor chuẩn hóa về ảnh xem được)
def denormalize(tensor, mean=IMG_MEAN, std=IMG_STD):
    # Clone để không ảnh hưởng tensor gốc
    tensor = tensor.clone().to('cpu')
    # Đảm bảo mean và std có đúng số chiều để broadcasting
    mean_tensor = torch.tensor(mean).view(1, 3, 1, 1)
    std_tensor = torch.tensor(std).view(1, 3, 1, 1)
    # Áp dụng công thức ngược: output = (input * std) + mean
    tensor.mul_(std_tensor).add_(mean_tensor)
    # Kẹp giá trị về [0, 1] phòng trường hợp có sai số nhỏ
    tensor = torch.clamp(tensor, 0, 1)
    return tensor

In [27]:
# --- 3. Custom Dataset ---
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        try:
            # Lấy danh sách tất cả các file ảnh hợp lệ
            self.image_files = [f for f in os.listdir(root_dir)
                                if os.path.isfile(os.path.join(root_dir, f))
                                and f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
            if not self.image_files:
                print(f"Warning: No valid image files found in {root_dir}")
        except FileNotFoundError:
             print(f"Error: Directory not found {root_dir}")
             self.image_files = []
        except Exception as e:
            print(f"Error reading directory {root_dir}: {e}")
            self.image_files = []


    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if idx >= len(self.image_files):
             raise IndexError("Index out of bounds")
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        try:
            image = Image.open(img_path).convert('RGB') # Đảm bảo ảnh là RGB
        except Exception as e:
            print(f"Error loading image {img_path}: {e}. Returning black image.")
            # Trả về ảnh đen nếu lỗi load
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # Trả về ảnh và nhãn dummy (không dùng trong style transfer)
        return image, 0

In [None]:
# --- 4. Định nghĩa các thành phần Model ---

# Hàm tính Mean/Std (Không inplace)
def calc_mean_std(features, eps=1e-5):
    batch_size, num_channels = features.size()[:2]
    # .var() và .mean() không phải inplace
    feature_var = features.view(batch_size, num_channels, -1).var(dim=2) + eps
    feature_std = feature_var.sqrt().view(batch_size, num_channels, 1, 1) # .sqrt() không inplace
    feature_mean = features.view(batch_size, num_channels, -1).mean(dim=2).view(batch_size, num_channels, 1, 1)
    return feature_mean, feature_std

# Hàm AdaIN (Không inplace)
def adaptive_instance_normalization(content_feat, style_feat):
    assert content_feat.size()[:2] == style_feat.size()[:2], "Batch size and channels must match"
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # Các phép toán +, -, *, / tạo tensor mới, .expand() tạo view (an toàn)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

# Decoder (Đảm bảo KHÔNG dùng inplace=True)
class Decoder(nn.Module):
    def __init__(self, input_dim, output_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # Input: B, C, 14, 14 -> Output: B, 3, 224, 224
            nn.ConvTranspose2d(input_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1), # 14->28
            nn.InstanceNorm2d(512),
            nn.ReLU(), # KHÔNG inplace=True
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), # 28->56
            nn.InstanceNorm2d(256),
            nn.ReLU(), # KHÔNG inplace=True
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # 56->112
            nn.InstanceNorm2d(128),
            nn.ReLU(), # KHÔNG inplace=True
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # 112->224
            nn.InstanceNorm2d(64),
            nn.ReLU(), # KHÔNG inplace=True
            # Lớp cuối cùng: Conv2d để giữ kích thước 224x224
            nn.Conv2d(64, output_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh() # Output trong khoảng [-1, 1]
        )

    def forward(self, x):
        return self.model(x)

# --- Mạng Style Transfer Chính ---
class ViTStyleTransfer(nn.Module):
    def __init__(self, vit_model_name='vit_base_patch16_224', vgg_weights=VGG19_Weights.DEFAULT):
        super().__init__()
        # 1. Encoder (ViT)
        self.vit_encoder = timm.create_model(vit_model_name, pretrained=True)
        self.embed_dim = self.vit_encoder.embed_dim
        self.vit_encoder.head = nn.Identity() # Bỏ lớp classification head
        # Lưu kích thước grid của patch embedding
        self.feature_map_size = self.vit_encoder.patch_embed.grid_size

        # 2. Decoder
        self.decoder = Decoder(input_dim=self.embed_dim)

        # 3. VGG Loss Network (Frozen)
        self.vgg_loss_net = self.load_vgg_for_loss(weights=vgg_weights)
        # Layer indices cho VGG19 features (thường dùng các lớp ReLU đầu mỗi block)
        # Ví dụ: relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
        self.vgg_loss_layer_indices = [1, 6, 11, 20, 29] # Kiểm tra lại với print(vgg19(weights=...).features)

    def load_vgg_for_loss(self, weights):
        """Tải VGG19 pre-trained và đóng băng các tham số."""
        vgg = vgg19(weights=weights).features
        # Đóng băng tất cả tham số của VGG
        for param in vgg.parameters():
            param.requires_grad = False
        vgg.eval() # Đặt VGG ở chế độ eval
        print("Loaded VGG19 for loss calculation. Frozen.")
        # print(vgg) # Bỏ comment để xem cấu trúc và indices
        return vgg.to(device) # Chuyển VGG loss net sang device luôn

    def get_vit_features(self, x):
        """Trích xuất patch embeddings từ ViT và reshape."""
        B = x.shape[0]
        x = self.vit_encoder.patch_embed(x) # B, num_patches, embed_dim

        # Thêm positional embedding (bỏ qua CLS token embedding nếu có)
        # Kiểm tra xem pos_embed có bao gồm CLS token không
        cls_token_present = hasattr(self.vit_encoder, 'cls_token') and self.vit_encoder.cls_token is not None
        if cls_token_present and self.vit_encoder.pos_embed.shape[1] == x.shape[1] + 1:
            pos_embed_used = self.vit_encoder.pos_embed[:, 1:, :] # Bỏ qua pos_embed của CLS
        else:
            pos_embed_used = self.vit_encoder.pos_embed[:, :, :] # Dùng hết pos_embed (nếu không có CLS hoặc shape khớp)
            if pos_embed_used.shape[1] != x.shape[1]:
                 print(f"Warning: Positional embedding shape {pos_embed_used.shape} might not match patch embedding shape {x.shape}. Adjusting...")
                 # Cố gắng resize hoặc điều chỉnh nếu cần - phức tạp, nên kiểm tra kỹ model ViT
                 pos_embed_used = pos_embed_used[:, :x.shape[1], :] # Cách xử lý đơn giản, có thể không tối ưu

        x = x + pos_embed_used
        if hasattr(self.vit_encoder, 'pos_drop'):
             x = self.vit_encoder.pos_drop(x)

        # Đi qua các block Transformer
        for blk in self.vit_encoder.blocks:
            x = blk(x)

        x = self.vit_encoder.norm(x) # LayerNorm cuối

        # Reshape thành feature map: (B, num_patches, embed_dim) -> (B, embed_dim, H, W)
        num_patches = x.shape[1]
        H, W = self.feature_map_size
        if H * W != num_patches:
             raise ValueError(f"Feature map size {H}x{W} does not match number of patches {num_patches}")
        x = x.permute(0, 2, 1).contiguous().view(B, self.embed_dim, H, W)
        return x

    def forward(self, content_img, style_img):
        """Quy trình forward: encode -> adain -> decode."""
        content_features = self.get_vit_features(content_img)
        style_features = self.get_vit_features(style_img)
        stylized_features = adaptive_instance_normalization(content_features, style_features)
        output_img = self.decoder(stylized_features)
        return output_img

    def get_vgg_loss_features(self, image):
        """Trích xuất features từ các lớp VGG đã chọn."""
        features = []
        x = image
        for i, layer in enumerate(self.vgg_loss_net):
            x = layer(x)
            if i in self.vgg_loss_layer_indices:
                features.append(x)
        return features

    def calculate_loss(self, generated_img, content_img, style_img):
        """Tính toán content loss và style loss sử dụng VGG features."""
        mse_loss = nn.MSELoss()

        # Trích xuất features từ ảnh tạo ra (cần grad)
        gen_vgg_features = self.get_vgg_loss_features(generated_img)

        # Trích xuất features từ ảnh gốc (KHÔNG cần grad qua đây)
        # Sử dụng torch.no_grad() để đảm bảo không có gradient không mong muốn
        with torch.no_grad():
            content_vgg_features = self.get_vgg_loss_features(content_img)
            style_vgg_features = self.get_vgg_loss_features(style_img)

        # Content Loss (thường dùng 1 lớp sâu, ví dụ lớp thứ 3 hoặc 4 trong list indices)
        content_loss_target_layer_index = 2 # Ví dụ dùng relu3_1
        content_loss = mse_loss(gen_vgg_features[content_loss_target_layer_index],
                                content_vgg_features[content_loss_target_layer_index]) # Target đã trong no_grad

        # Style Loss (tính trên tất cả các lớp đã chọn)
        style_loss = 0
        for gen_feat, style_feat in zip(gen_vgg_features, style_vgg_features):
            # Tính mean/std cho features ảnh tạo ra (cần grad)
            gen_mean, gen_std = calc_mean_std(gen_feat)
            # Tính mean/std cho features ảnh style (target - không cần grad)
            style_mean, style_std = calc_mean_std(style_feat) # style_feat đã trong no_grad
            # Cộng dồn loss
            style_loss += mse_loss(gen_mean, style_mean) + mse_loss(gen_std, style_std)

        return content_loss, style_loss

In [29]:
# --- 5. Thiết lập đường dẫn và DataLoaders ---
content_dir_train = '/kaggle/input/avataranime/AvatarAnime/CelebA/train'
style_dir_train = '/kaggle/input/avataranime/AvatarAnime/AnimeFace/train'
content_dir_val = '/kaggle/input/avataranime/AvatarAnime/CelebA/val'
style_dir_val = '/kaggle/input/avataranime/AvatarAnime/AnimeFace/val'

# Kiểm tra sự tồn tại của thư mục
for path in [content_dir_train, style_dir_train, content_dir_val, style_dir_val]:
    if not os.path.isdir(path):
        print(f"FATAL ERROR: Directory not found - {path}. Please check the path.")
        # Có thể dừng script ở đây nếu muốn
        # raise FileNotFoundError(f"Directory not found: {path}")

# Tạo datasets với CustomImageDataset
print(f"Loading Content Train data from: {content_dir_train}")
content_dataset_train = CustomImageDataset(content_dir_train, transform=train_transform)
print(f"Found {len(content_dataset_train)} content train images.")

print(f"Loading Style Train data from: {style_dir_train}")
style_dataset_train = CustomImageDataset(style_dir_train, transform=train_transform)
print(f"Found {len(style_dataset_train)} style train images.")

print(f"Loading Content Validation data from: {content_dir_val}")
content_dataset_val = CustomImageDataset(content_dir_val, transform=test_transform)
print(f"Found {len(content_dataset_val)} content validation images.")

print(f"Loading Style Validation data from: {style_dir_val}")
style_dataset_val = CustomImageDataset(style_dir_val, transform=test_transform)
print(f"Found {len(style_dataset_val)} style validation images.")

# Kiểm tra dataset có dữ liệu không
if len(content_dataset_train) == 0 or len(style_dataset_train) == 0:
     raise ValueError("Training dataset(s) are empty. Check paths and image files.")
if len(content_dataset_val) == 0 or len(style_dataset_val) == 0:
     print("Warning: Validation dataset(s) are empty.")

# Tạo DataLoaders
BATCH_SIZE = 8 # Giảm nếu gặp lỗi CUDA out of memory
NUM_WORKERS = 2 # Số worker cho DataLoader
content_loader_train = DataLoader(content_dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
style_loader_train = DataLoader(style_dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
content_loader_val = DataLoader(content_dataset_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
style_loader_val = DataLoader(style_dataset_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print("DataLoaders created.")

Loading Content Train data from: /kaggle/input/avataranime/AvatarAnime/CelebA/train
Found 800 content train images.
Loading Style Train data from: /kaggle/input/avataranime/AvatarAnime/AnimeFace/train
Found 800 style train images.
Loading Content Validation data from: /kaggle/input/avataranime/AvatarAnime/CelebA/val
Found 100 content validation images.
Loading Style Validation data from: /kaggle/input/avataranime/AvatarAnime/AnimeFace/val
Found 100 style validation images.
DataLoaders created.


In [None]:
# --- 6. Khởi tạo Model, Optimizer ---
print("Initializing Model and Optimizer...")
model = ViTStyleTransfer(vit_model_name='vit_base_patch16_224').to(device) # Chuyển model gốc sang device chính trước

# KIỂM TRA VÀ SỬ DỤNG DATAPARALLEL NẾU CÓ NHIỀU GPU
if torch.cuda.device_count() > 1:
  print(f"Let's use {torch.cuda.device_count()} GPUs!")
  # device_ids có thể được chỉ định rõ ràng, ví dụ [0, 1]
  # Nếu không chỉ định, nó sẽ dùng tất cả GPU khả dụng
  model = nn.DataParallel(model)
  # Lưu ý: DataParallel sẽ tự động chuyển model.module sang các device_ids được chỉ định
  # nhưng model gốc (đã được bọc) vẫn cần được to(device) (device chính, thường là cuda:0)
  # Trong trường hợp này, model đã được to(device) trước đó.

# Optimizer vẫn khởi tạo bình thường, nhưng với model.parameters()
# Nếu model được bọc bởi DataParallel, nên dùng model.module.parameters()
# nếu muốn truy cập tham số của model gốc, tuy nhiên,
# optimizer(model.parameters()) vẫn hoạt động vì DataParallel chuyển tiếp việc truy cập này.
optimizer = optim.AdamW(
    model.parameters(), # DataParallel xử lý việc này đúng cách
    lr=1e-4
)

print("Model and Optimizer initialized.")
# Nếu model được bọc DataParallel, khi lưu và load checkpoint, cần lưu ý:
# - Khi lưu: torch.save(model.module.state_dict(), PATH)
# - Khi load: model_goc.load_state_dict(torch.load(PATH)), sau đó mới bọc DataParallel nếu cần

Initializing Model and Optimizer...
Loaded VGG19 for loss calculation. Frozen.
Let's use 2 GPUs!
Model and Optimizer initialized.


In [31]:
# prompt: remove /kaggle/working/ViT_StyleTransfer directory

import shutil
import os

# Define the directory to remove
dir_to_remove = '/kaggle/working/ViT_StyleTransfer'

# Check if the directory exists
if os.path.exists(dir_to_remove):
  try:
    # Use shutil.rmtree to remove the directory and its contents
    shutil.rmtree(dir_to_remove)
    print(f"Directory '{dir_to_remove}' and its contents removed successfully.")
  except OSError as e:
    print(f"Error: {e.filename} - {e.strerror}.")
else:
  print(f"Directory '{dir_to_remove}' does not exist.")


Directory '/kaggle/working/ViT_StyleTransfer' and its contents removed successfully.


In [None]:
# --- 7. Training Loop ---
NUM_EPOCHS = 50
LAMBDA_CONTENT = 3.0
LAMBDA_STYLE = 1.0 # Tỉ lệ style loss so với content loss

# Thư mục lưu kết quả
output_dir = '/kaggle/working/ViT_StyleTransfer' # Đổi tên để tránh ghi đè
checkpoint_dir = os.path.join(output_dir, 'checkpoints')
results_dir = os.path.join(output_dir, 'results')
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# Bật anomaly detection để tìm lỗi cụ thể nếu có (có thể làm chậm quá trình huấn luyện)
# torch.autograd.set_detect_anomaly(True)

print("Starting Training...")
for epoch in range(NUM_EPOCHS):
    model.train() # Đặt model ở chế độ train
    epoch_start_time = time.time()
    running_content_loss = 0.0
    running_style_loss = 0.0
    running_total_loss = 0.0

    # Lấy số batch tối thiểu giữa 2 loader để tránh lỗi index
    num_batches = min(len(content_loader_train), len(style_loader_train))
    # Tạo iterator vô hạn cho style loader để đảm bảo luôn có style image
    style_iterator = iter(style_loader_train)

    pbar = tqdm(enumerate(content_loader_train), total=num_batches, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for i, (content_batch, _) in pbar:
        if i >= num_batches: # Dừng epoch nếu đã đi qua số batch tối thiểu
             break

        # Lấy style batch, quay vòng nếu cần
        try:
            style_batch, _ = next(style_iterator)
        except StopIteration:
            style_iterator = iter(style_loader_train) # Reset iterator
            style_batch, _ = next(style_iterator)

        # Chuyển dữ liệu lên device
        content_images = content_batch.to(device)
        style_images = style_batch.to(device)

        # --- Forward Pass ---
        optimizer.zero_grad() # Xóa gradient cũ
        generated_images = model(content_images, style_images)

        # --- Loss Calculation ---
        if isinstance(model, nn.DataParallel):
            content_loss, style_loss = model.module.calculate_loss(generated_images, content_images, style_images)
        else:
            content_loss, style_loss = model.calculate_loss(generated_images, content_images, style_images)
        total_loss = LAMBDA_CONTENT * content_loss + LAMBDA_STYLE * style_loss

        # --- Backward Pass and Optimize ---
        # Kiểm tra loss có phải NaN không
        if torch.isnan(total_loss):
            print(f"Warning: NaN loss detected at batch {i}. Skipping backward pass.")
            # Cân nhắc dừng huấn luyện hoặc giảm learning rate nếu xảy ra thường xuyên
            continue # Bỏ qua batch này

        total_loss.backward() # Tính gradient
        optimizer.step()      # Cập nhật trọng số

        # --- Logging ---
        running_content_loss += content_loss.item()
        running_style_loss += style_loss.item()
        running_total_loss += total_loss.item()

        # Cập nhật progress bar
        pbar.set_postfix({
            'Content L': f"{content_loss.item():.3f}",
            'Style L': f"{style_loss.item():.3f}",
            'Total L': f"{total_loss.item():.3f}"
        })

    # --- End of Epoch ---
    avg_content_loss = running_content_loss / num_batches if num_batches > 0 else 0
    avg_style_loss = running_style_loss / num_batches if num_batches > 0 else 0
    avg_total_loss = running_total_loss / num_batches if num_batches > 0 else 0
    epoch_duration = time.time() - epoch_start_time

    print(f"\nEpoch {epoch+1} Summary: Avg Loss (Total: {avg_total_loss:.4f}, Content: {avg_content_loss:.4f}, Style: {avg_style_loss:.4f}), Time: {epoch_duration:.2f}s")

    # --- Validation and Visualization ---
    model.eval() # Đặt model ở chế độ eval
    with torch.no_grad(): # Không tính gradient trong validation
        # Lấy một vài ảnh cố định từ validation set
        try:
             val_content_batch, _ = next(iter(content_loader_val))
             val_style_batch, _ = next(iter(style_loader_val))
        except StopIteration:
             print("Warning: Validation loader is empty, skipping visualization.")
             continue # Bỏ qua visualization nếu không có dữ liệu val

        val_content = val_content_batch[:4].to(device) # Lấy 4 ảnh đầu
        val_style = val_style_batch[:4].to(device)
        val_generated = model(val_content, val_style)

        # Denormalize để hiển thị
        val_content_vis = denormalize(val_content)
        val_style_vis = denormalize(val_style)
        val_generated_vis = denormalize(val_generated)

        # Vẽ và lưu ảnh
        fig, axes = plt.subplots(3, 4, figsize=(12, 9))
        fig.suptitle(f'Epoch {epoch+1} Validation Results', fontsize=16)
        for j in range(4):
            # Content
            axes[0, j].imshow(val_content_vis[j].permute(1, 2, 0))
            axes[0, j].set_title("Content")
            axes[0, j].axis('off')
            # Style
            axes[1, j].imshow(val_style_vis[j].permute(1, 2, 0))
            axes[1, j].set_title("Style")
            axes[1, j].axis('off')
            # Generated
            axes[2, j].imshow(val_generated_vis[j].permute(1, 2, 0))
            axes[2, j].set_title("Generated")
            axes[2, j].axis('off')
        plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
        save_path = os.path.join(results_dir, f'epoch_{epoch+1:03d}_validation.png')
        plt.savefig(save_path)
        # plt.show() # Hiển thị trong Colab output (có thể bỏ nếu chỉ cần lưu file)
        plt.close(fig) # Đóng figure để giải phóng bộ nhớ

    # --- Save Checkpoint ---
    if epoch%5 == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1:03d}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'avg_total_loss': avg_total_loss,
            'avg_content_loss': avg_content_loss,
            'avg_style_loss': avg_style_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    # --- (Optional) Update Learning Rate Scheduler ---
    # if scheduler:
    #     scheduler.step(avg_total_loss)
    #     current_lr = optimizer.param_groups[0]['lr']
    #     print(f"Current LR: {current_lr}")

print("Training Finished!")

# Tắt anomaly detection sau khi huấn luyện xong
# torch.autograd.set_detect_anomaly(False)

Starting Training...


Epoch 1/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1 Summary: Avg Loss (Total: 26.3829, Content: 5.8432, Style: 8.8534), Time: 80.83s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_001.pth


Epoch 2/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 2 Summary: Avg Loss (Total: 21.6025, Content: 5.2423, Style: 5.8757), Time: 80.28s


Epoch 3/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 3 Summary: Avg Loss (Total: 19.0779, Content: 4.5774, Style: 5.3457), Time: 79.90s


Epoch 4/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 4 Summary: Avg Loss (Total: 16.7014, Content: 3.8911, Style: 5.0282), Time: 79.46s


Epoch 5/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 5 Summary: Avg Loss (Total: 15.1536, Content: 3.4581, Style: 4.7794), Time: 79.76s


Epoch 6/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 6 Summary: Avg Loss (Total: 14.0609, Content: 3.1453, Style: 4.6250), Time: 79.58s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_006.pth


Epoch 7/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 7 Summary: Avg Loss (Total: 13.1504, Content: 2.8791, Style: 4.5132), Time: 79.90s


Epoch 8/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 8 Summary: Avg Loss (Total: 12.4159, Content: 2.6593, Style: 4.4380), Time: 79.80s


Epoch 9/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 9 Summary: Avg Loss (Total: 11.8817, Content: 2.4963, Style: 4.3929), Time: 79.68s


Epoch 10/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 10 Summary: Avg Loss (Total: 11.3925, Content: 2.3552, Style: 4.3271), Time: 80.19s


Epoch 11/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 11 Summary: Avg Loss (Total: 11.0439, Content: 2.2490, Style: 4.2970), Time: 80.07s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_011.pth


Epoch 12/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 12 Summary: Avg Loss (Total: 10.7078, Content: 2.1555, Style: 4.2411), Time: 80.02s


Epoch 13/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 13 Summary: Avg Loss (Total: 10.5013, Content: 2.0900, Style: 4.2313), Time: 79.75s


Epoch 14/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 14 Summary: Avg Loss (Total: 10.2392, Content: 2.0214, Style: 4.1751), Time: 80.24s


Epoch 15/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 15 Summary: Avg Loss (Total: 10.1281, Content: 1.9871, Style: 4.1669), Time: 79.52s


Epoch 16/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 16 Summary: Avg Loss (Total: 9.9379, Content: 1.9380, Style: 4.1238), Time: 80.13s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_016.pth


Epoch 17/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 17 Summary: Avg Loss (Total: 9.8083, Content: 1.8926, Style: 4.1304), Time: 79.51s


Epoch 18/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 18 Summary: Avg Loss (Total: 9.6513, Content: 1.8462, Style: 4.1127), Time: 80.28s


Epoch 19/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 19 Summary: Avg Loss (Total: 9.5532, Content: 1.8229, Style: 4.0844), Time: 80.13s


Epoch 20/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 20 Summary: Avg Loss (Total: 9.4509, Content: 1.7896, Style: 4.0821), Time: 79.73s


Epoch 21/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 21 Summary: Avg Loss (Total: 9.3987, Content: 1.7730, Style: 4.0796), Time: 79.45s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_021.pth


Epoch 22/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 22 Summary: Avg Loss (Total: 9.2942, Content: 1.7478, Style: 4.0509), Time: 79.54s


Epoch 23/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 23 Summary: Avg Loss (Total: 9.2192, Content: 1.7235, Style: 4.0486), Time: 79.70s


Epoch 24/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 24 Summary: Avg Loss (Total: 9.1631, Content: 1.7100, Style: 4.0332), Time: 79.62s


Epoch 25/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 25 Summary: Avg Loss (Total: 9.1082, Content: 1.6974, Style: 4.0160), Time: 80.25s


Epoch 26/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 26 Summary: Avg Loss (Total: 9.0852, Content: 1.6910, Style: 4.0121), Time: 79.63s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_026.pth


Epoch 27/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 27 Summary: Avg Loss (Total: 8.9495, Content: 1.6479, Style: 4.0056), Time: 79.67s


Epoch 28/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 28 Summary: Avg Loss (Total: 8.8958, Content: 1.6346, Style: 3.9918), Time: 79.57s


Epoch 29/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 29 Summary: Avg Loss (Total: 8.8312, Content: 1.6172, Style: 3.9796), Time: 79.78s


Epoch 30/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 30 Summary: Avg Loss (Total: 8.7673, Content: 1.6050, Style: 3.9521), Time: 79.59s


Epoch 31/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 31 Summary: Avg Loss (Total: 8.7344, Content: 1.5946, Style: 3.9506), Time: 80.21s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_031.pth


Epoch 32/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 32 Summary: Avg Loss (Total: 8.6828, Content: 1.5803, Style: 3.9420), Time: 79.62s


Epoch 33/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 33 Summary: Avg Loss (Total: 8.6642, Content: 1.5738, Style: 3.9429), Time: 79.57s


Epoch 34/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 34 Summary: Avg Loss (Total: 8.5899, Content: 1.5541, Style: 3.9275), Time: 80.33s


Epoch 35/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 35 Summary: Avg Loss (Total: 8.5505, Content: 1.5463, Style: 3.9117), Time: 79.55s


Epoch 36/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 36 Summary: Avg Loss (Total: 8.5127, Content: 1.5332, Style: 3.9133), Time: 79.79s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_036.pth


Epoch 37/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 37 Summary: Avg Loss (Total: 8.4803, Content: 1.5215, Style: 3.9157), Time: 79.71s


Epoch 38/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 38 Summary: Avg Loss (Total: 8.4546, Content: 1.5158, Style: 3.9071), Time: 80.26s


Epoch 39/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 39 Summary: Avg Loss (Total: 8.4030, Content: 1.5055, Style: 3.8865), Time: 79.77s


Epoch 40/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 40 Summary: Avg Loss (Total: 8.3719, Content: 1.4953, Style: 3.8859), Time: 79.74s


Epoch 41/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 41 Summary: Avg Loss (Total: 8.3310, Content: 1.4880, Style: 3.8668), Time: 79.75s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_041.pth


Epoch 42/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 42 Summary: Avg Loss (Total: 8.3174, Content: 1.4827, Style: 3.8695), Time: 80.48s


Epoch 43/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 43 Summary: Avg Loss (Total: 8.2700, Content: 1.4704, Style: 3.8589), Time: 79.89s


Epoch 44/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 44 Summary: Avg Loss (Total: 8.4786, Content: 1.5290, Style: 3.8915), Time: 79.57s


Epoch 45/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 45 Summary: Avg Loss (Total: 8.2539, Content: 1.4640, Style: 3.8618), Time: 80.71s


Epoch 46/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 46 Summary: Avg Loss (Total: 8.2127, Content: 1.4587, Style: 3.8366), Time: 79.57s
Checkpoint saved to /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_046.pth


Epoch 47/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 47 Summary: Avg Loss (Total: 8.1494, Content: 1.4401, Style: 3.8291), Time: 79.61s


Epoch 48/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 48 Summary: Avg Loss (Total: 8.1730, Content: 1.4436, Style: 3.8421), Time: 79.69s


Epoch 49/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 49 Summary: Avg Loss (Total: 8.0825, Content: 1.4222, Style: 3.8158), Time: 80.61s


Epoch 50/50:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 50 Summary: Avg Loss (Total: 8.0908, Content: 1.4244, Style: 3.8176), Time: 79.78s
Training Finished!


In [None]:
# --- 8. Inference Function ---
def run_inference(content_path, style_path, model_checkpoint_path, output_path):
    print(f"Running inference...")
    print(f"Content: {content_path}")
    print(f"Style: {style_path}")
    print(f"Checkpoint: {model_checkpoint_path}")

    try:
        # Thêm weights_only=True để tăng cường bảo mật và tránh warning
        checkpoint = torch.load(model_checkpoint_path, map_location=device, weights_only=True)
        
        # Khởi tạo model gốc
        inference_model = ViTStyleTransfer().to(device)
        
        # Lấy state_dict từ checkpoint
        state_dict_from_checkpoint = checkpoint['model_state_dict']
        
        # Kiểm tra và loại bỏ tiền tố 'module.' nếu có
        # Tạo một state_dict mới không có tiền tố 'module.'
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        is_data_parallel_checkpoint = False
        for k, v in state_dict_from_checkpoint.items():
            if k.startswith('module.'):
                new_state_dict[k[7:]] = v # Loại bỏ 'module.' (7 ký tự)
                is_data_parallel_checkpoint = True
            else:
                new_state_dict[k] = v # Giữ nguyên nếu không có tiền tố
        
        if is_data_parallel_checkpoint:
            print("Checkpoint was saved from a DataParallel model. Removing 'module.' prefix.")
            inference_model.load_state_dict(new_state_dict)
        else:
            print("Checkpoint was saved from a single GPU model or 'module.' prefix already removed.")
            inference_model.load_state_dict(state_dict_from_checkpoint) # Load trực tiếp

        inference_model.eval() # Đặt chế độ inference
        print(f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')}")

    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    # Load và xử lý ảnh input
    try:
        content_img = Image.open(content_path).convert('RGB')
        style_img = Image.open(style_path).convert('RGB')
    except FileNotFoundError as e:
         print(f"Error opening image file: {e}")
         return
    except Exception as e:
         print(f"Error processing image file: {e}")
         return

    content_tensor = test_transform(content_img).unsqueeze(0).to(device)
    style_tensor = test_transform(style_img).unsqueeze(0).to(device)

    # Thực hiện style transfer
    with torch.no_grad():
        output_tensor = inference_model(content_tensor, style_tensor)

    # Denormalize và chuyển về PIL Image
    output_tensor_denorm = denormalize(output_tensor.cpu()) # Chuyển về CPU trước khi denorm
    output_img = transforms.ToPILImage()(output_tensor_denorm.squeeze(0)) # Bỏ chiều batch

    # Lưu ảnh kết quả
    try:
        output_img.save(output_path)
        print(f"Inference result saved to {output_path}")
        # display(output_img) # Hiển thị nếu chạy trong notebook
    except Exception as e:
        print(f"Error saving inference result: {e}")

# Sử dụng inference
print("\n--- Running Example Inference ---")
example_content = '/kaggle/input/avataranime/AvatarAnime/CelebA/test/019692.jpg' 
example_style = '/kaggle/input/avataranime/AvatarAnime/AnimeFace/test/2091_2001.jpg'
for i in range(1, 51):
    example_checkpoint = os.path.join(checkpoint_dir, f'model_epoch_{i:03d}.pth') 
    example_output = os.path.join(results_dir, f'inference_example_output_{i:03d}.jpg')

    if os.path.exists(example_content) and os.path.exists(example_style) and os.path.exists(example_checkpoint):
        run_inference(example_content, example_style, example_checkpoint, example_output)
    else:
        print("Skipping example inference: Input files or checkpoint not found.")



--- Running Example Inference ---
Running inference...
Content: /kaggle/input/avataranime/AvatarAnime/CelebA/test/019692.jpg
Style: /kaggle/input/avataranime/AvatarAnime/AnimeFace/test/2091_2001.jpg
Checkpoint: /kaggle/working/ViT_StyleTransfer/checkpoints/model_epoch_001.pth
Loaded VGG19 for loss calculation. Frozen.
Checkpoint was saved from a DataParallel model. Removing 'module.' prefix.
Loaded model from epoch 1
Inference result saved to /kaggle/working/ViT_StyleTransfer/results/inference_example_output_001.jpg
Skipping example inference: Input files or checkpoint not found.
Skipping example inference: Input files or checkpoint not found.
Skipping example inference: Input files or checkpoint not found.
Skipping example inference: Input files or checkpoint not found.
Running inference...
Content: /kaggle/input/avataranime/AvatarAnime/CelebA/test/019692.jpg
Style: /kaggle/input/avataranime/AvatarAnime/AnimeFace/test/2091_2001.jpg
Checkpoint: /kaggle/working/ViT_StyleTransfer/checkp

In [45]:
!zip -r ViT_StyleTransfer_results.zip /kaggle/working/ViT_StyleTransfer/results


  adding: kaggle/working/ViT_StyleTransfer/results/ (stored 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/inference_example_output_031.jpg (deflated 2%)
  adding: kaggle/working/ViT_StyleTransfer/results/inference_example_output_001.jpg (deflated 14%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_049_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_037_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_022_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_032_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/inference_example_output_011.jpg (deflated 2%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_048_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_035_validation.png (deflated 0%)
  adding: kaggle/working/ViT_StyleTransfer/results/epoch_039_validation.png (deflated 0%)
  adding: k

In [None]:
from IPython.display import FileLink
FileLink(r'ViT_StyleTransfer_results.zip')