In [None]:
import cv2
import numpy as np
import os
import time, random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.models import vgg16, VGG16_Weights

In [None]:

class VGG16(nn.Module):
    """Frozen VGG-16 convolutional stack used as the perceptual-loss network.

    The weights are read from a local checkpoint file and every parameter of
    ``self.features`` has ``requires_grad`` switched off, so the module is only
    ever used to extract activations.

    Args:
        vgg_path: Path to a VGG-16 ``state_dict`` checkpoint.

    Raises:
        RuntimeError: If the checkpoint does not provide every ``features.*``
            parameter (the network would otherwise silently keep random
            weights because the load is non-strict).
    """

    # Child index inside ``self.features`` -> label used for the returned activation.
    LAYER_NAMES = {'3': 'relu1_2',
                   '8': 'relu2_2',
                   '15': 'relu3_3',
                   '22': 'relu4_3'}
    # Nothing after this child is needed, so forward() stops there.
    LAST_LAYER = '22'

    def __init__(self, vgg_path="/kaggle/input/vgg-pretrained/pytorch/default/1/vgg16-00b39a1b.pth"):
        super(VGG16, self).__init__()
        vgg16_features = vgg16(weights=None)
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only point vgg_path at a checkpoint you trust.
        state = torch.load(vgg_path, map_location='cpu', weights_only=False)

        # The load stays non-strict (as before) so mismatches outside the conv
        # stack are tolerated, but a checkpoint that lacks conv weights must not
        # pass unnoticed.
        load_result = vgg16_features.load_state_dict(state, strict=False)
        missing_features = [key for key in load_result.missing_keys
                            if key.startswith('features.')]
        if missing_features:
            raise RuntimeError(
                "VGG16 checkpoint {!r} is missing feature weights: {}".format(
                    vgg_path, missing_features))

        self.features = vgg16_features.features

        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the conv stack and collect the tapped activations.

        Returns:
            dict mapping each label in ``LAYER_NAMES`` to the activation produced
            by the corresponding child module, in execution order.
        """
        features = {}
        for name, layer in self.features.named_children():
            x = layer(x)
            label = self.LAYER_NAMES.get(name)
            if label is not None:
                features[label] = x
            if name == self.LAST_LAYER:
                # Later layers are never used by the losses; skip them.
                break

        return features

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: conv -> instance norm -> ReLU -> conv -> instance norm, plus skip.

    Args:
        c: Number of input (and output) channels; the block keeps the shape of
           its input so the skip connection can be a plain addition.
    """

    def __init__(self, c):
        super(ResidualBlock, self).__init__()
        stages = []
        for followed_by_relu in (True, False):
            stages.append(nn.Conv2d(c, c, 3, padding='same'))
            stages.append(nn.InstanceNorm2d(c, affine=True, track_running_stats=False))
            if followed_by_relu:
                stages.append(nn.ReLU())
        # Layer order (and therefore the state_dict indices 0,1,3,4) is unchanged.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the output of the conv/norm stack."""
        residual = self.block(x)
        return x + residual

class TransformerNetModern(nn.Module):
    """Feed-forward image transformation network (encoder -> residual blocks -> decoder).

    The encoder downsamples twice with stride-2 convolutions, the decoder
    upsamples twice with nearest-neighbour ``Upsample`` followed by a
    convolution, and the last layer is ``tanh`` scaled by ``tanh_multiplier``.

    Args:
        tanh_multiplier: Factor applied to the final ``tanh`` output, so the
            result lies in ``[-tanh_multiplier, tanh_multiplier]``. The original
            author's note says larger values give a brighter result.
        num_res_blocks: Number of ``ResidualBlock(128)`` units between encoder
            and decoder. Defaults to 5, the value that used to be hard-coded.
            Changing it changes the ``state_dict`` keys, so weights trained with
            a different count will not load with ``strict=True``.
    """

    def __init__(self, tanh_multiplier=180.0, num_res_blocks=5):
        super().__init__()

        # Tuning knob noted by the original author: the channel widths here can be increased.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 9, stride=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Tuning knob noted by the original author: more residual blocks (e.g. 8 instead of 5).
        self.resblocks = nn.Sequential(
            *[ResidualBlock(128) for _ in range(num_res_blocks)]
        )

        # Decoder uses upsample + convolution (author's note: chosen to avoid
        # checkerboard artifacts).
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, 64, 3, stride=1, padding=0),
            nn.InstanceNorm2d(64, affine=True, track_running_stats=False),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 32, 3, stride=1, padding=0),
            nn.InstanceNorm2d(32, affine=True, track_running_stats=False),
            nn.ReLU(),

            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, 9, stride=1, padding=0),
            nn.Tanh()
        )
        self.tanh_multiplier = tanh_multiplier

    def forward(self, x):
        """Transform a batch of images; the output is the scaled tanh of the decoder."""
        out = self.decoder(self.resblocks(self.encoder(x))) * self.tanh_multiplier
        return out

In [None]:
def gram(tensor):
    """Compute the normalised Gram matrix of a batch of feature maps.

    Args:
        tensor: Feature maps of shape ``[B, C, H, W]``.

    Returns:
        Tensor of shape ``[B, C, C]`` holding, for every sample, the channel by
        channel inner products divided by ``C * H * W``.
    """
    B, C, H, W = tensor.shape
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
    # feature maps, are accepted instead of raising.
    flat = tensor.reshape(B, C, H * W)
    return torch.bmm(flat, flat.transpose(1, 2)) / (C * H * W)

def load_image(path):
    """Read an image from disk with OpenCV.

    Args:
        path: Location of the image file.

    Returns:
        The array returned by ``cv2.imread`` (the original comment notes that
        images are loaded in BGR channel order).

    Raises:
        FileNotFoundError: If ``path`` is not an existing file.
        ValueError: If the file exists but OpenCV could not decode it.
    """
    # cv2.imread signals failure by returning None rather than raising, which
    # would only surface later as an obscure error; fail here with a clear one.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path}")

    # Images loaded as BGR
    img = cv2.imread(path)
    if img is None:
        raise ValueError(f"Could not decode image file: {path}")
    return img

def saveimg(img, image_path):
    """Clamp the pixel values to [0, 255] and write the image to ``image_path`` via OpenCV."""
    clamped = img.clip(min=0, max=255)
    cv2.imwrite(image_path, clamped)

def itot(img, max_size=None):
    """Convert an image array to a batched tensor, optionally resizing it first.

    Args:
        img: Image array laid out as ``[H, W, C]``.
        max_size: If given, the image is resized so that its longer side becomes
            ``max_size`` while the aspect ratio is kept. ``None`` keeps the
            original size.

    Returns:
        Tensor of shape ``[1, C, H, W]``: the ``ToTensor`` output multiplied by
        255 with a leading batch dimension added.
    """
    steps = []
    if max_size is not None:
        H, W = img.shape[:2]
        scale = float(max_size) / max(H, W)
        # Guard against a side rounding down to 0 for extreme aspect ratios.
        image_size = tuple(max(1, int(scale * side)) for side in (H, W))
        steps.extend([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
        ])
    # Shared tail of the pipeline: array/PIL image -> tensor, then rescale by 255.
    steps.extend([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])

    tensor = transforms.Compose(steps)(img)
    return tensor.unsqueeze(dim=0)

def ttoi(tensor):
    """Convert an image tensor to a NumPy array laid out as ``[H, W, C]``.

    Args:
        tensor: Tensor of shape ``[1, C, H, W]`` or ``[C, H, W]``.

    Returns:
        NumPy array of shape ``[H, W, C]``.
    """
    # Remove only the batch dimension. A bare squeeze() would also drop any
    # size-1 height, width or channel dimension and break the transpose below.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)

    # detach() so tensors that still carry autograd history can be converted.
    img = tensor.detach().cpu().numpy()

    # Transpose from [C, H, W] -> [H, W, C]
    return img.transpose(1, 2, 0)

def plot_loss_hist(c_loss, s_loss, total_loss, title="Loss History", save_dir="/kaggle/working/", log_every=None):
    """Plot the content, style and total loss curves, save the figure and show it.

    Args:
        c_loss: Sequence of content-loss values.
        s_loss: Sequence of style-loss values.
        total_loss: Sequence of total-loss values; its length sets the x axis.
        title: Figure title, also used (lower-cased, spaces -> underscores) as
            the PNG file name.
        save_dir: Directory the PNG is written to; created if missing.
        log_every: Optional number of iterations between two logged points,
            used only for the x-axis label. The label used to hard-code
            "Every 500 iterations", which was wrong whenever the caller logged
            at a different interval.
    """
    os.makedirs(save_dir, exist_ok=True)
    steps = list(range(len(total_loss)))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(steps, c_loss, label="Content Loss")
    ax.plot(steps, s_loss, label="Style Loss")
    ax.plot(steps, total_loss, label="Total Loss")
    ax.legend()
    if log_every:
        ax.set_xlabel(f'Every {log_every} iterations')
    else:
        ax.set_xlabel('Logging step')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.6)

    save_path = os.path.join(save_dir, f"{title.replace(' ', '_').lower()}.png")
    fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
# --- Input data ---
DATASET_PATH = "/kaggle/input/dataset"
STYLE_IMAGE_PATH = "/kaggle/input/styleimage/la_muse.jpg"
TRAIN_IMAGE_SIZE = 256  # training images are resized and center-cropped to this size

# --- Optimisation ---
SEED = 35
NUM_EPOCHS = 4
BATCH_SIZE = 8
ADAM_LR = 0.001
CONTENT_WEIGHT = 17  # multiplier on the content loss
STYLE_WEIGHT = 50  # multiplier on the summed style loss

# --- Outputs and logging ---
SAVE_MODEL_PATH = "/kaggle/working/models/"
SAVE_IMAGE_PATH = "/kaggle/working/images/"
SAVE_MODEL_EVERY = 1000  # iterations between checkpoints / sample images / loss-history points
PLOT_LOSS = 1  # truthy -> plot the loss history once training finishes

def train() -> None:
    """Train the TransformerNetModern style-transfer network and save its weights.

    All settings come from the module-level constants (SEED, DATASET_PATH,
    STYLE_IMAGE_PATH, TRAIN_IMAGE_SIZE, BATCH_SIZE, NUM_EPOCHS, ADAM_LR,
    CONTENT_WEIGHT, STYLE_WEIGHT, SAVE_MODEL_PATH, SAVE_IMAGE_PATH,
    SAVE_MODEL_EVERY, PLOT_LOSS).

    Per batch the loss is
        CONTENT_WEIGHT * MSE(relu2_2 of generated, relu2_2 of content)
        + STYLE_WEIGHT * sum over the VGG layers of MSE(gram(generated), gram(style)).

    Side effects:
        * creates SAVE_MODEL_PATH and SAVE_IMAGE_PATH;
        * on iterations 1, 1 + SAVE_MODEL_EVERY, ... and on the last iteration,
          prints the running-mean losses, writes ``checkpoint_{i}.pth`` and
          ``sample0_{i}.png`` and appends to the loss histories;
        * after training, moves the network to CPU in eval mode, writes
          ``transformer_weight.pth`` and, if PLOT_LOSS is truthy, plots the
          loss histories.
    """
    # Seed every RNG used here so runs are repeatable.
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    os.makedirs(SAVE_MODEL_PATH, exist_ok=True)
    os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)

    device = ("cuda" if torch.cuda.is_available() else "cpu")
    # Resize + center-crop to a square, then rescale the ToTensor output by 255.
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    TransformerNetwork = TransformerNetModern().to(device)
    VGG = VGG16().to(device)

    MSELoss = nn.MSELoss().to(device)
    # Only the transformer network is optimised.
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Per-channel offsets added to every image before it goes through VGG.
    # NOTE(review): presumably the negated ImageNet means in BGR order on a 0-255 scale - confirm.
    imagenet_neg_mean = torch.tensor(
        [-103.939, -116.779, -123.68], 
        dtype=torch.float32).reshape(1,3,1,1).to(device)
    
    style_image = load_image(STYLE_IMAGE_PATH)
    style_tensor = itot(style_image).to(device)
     
    with torch.no_grad(): 
        style_tensor_norm = style_tensor.add(imagenet_neg_mean) # Normalize (add the negative mean)
        
        # Only the single style tensor of shape [1, C, H, W] is passed through VGG
        style_features_single = VGG(style_tensor_norm)
    
    # Compute the Gram matrices once and keep them
    style_gram = {}
    for key, value in style_features_single.items():
        # Gram matrix computed from a batch of size 1
        style_gram[key] = gram(value)

    
    # Histories hold one running-mean value per logging step.
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    # Running sums over all iterations since the start of training.
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # 1-based iteration counter across all epochs.
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch+1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # The last batch of an epoch may be smaller than BATCH_SIZE.
            curr_batch_size = content_batch.shape[0]
            torch.cuda.empty_cache()
            optimizer.zero_grad()

            # Reverse the channel order of the loaded batch (index [2,1,0]) before use.
            content_batch = content_batch[:,[2,1,0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content loss compares the relu2_2 activations only.
            content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu2_2'], content_features['relu2_2'])            
            batch_content_loss_sum += content_loss.item()

            # Style loss sums the Gram-matrix MSE over every returned VGG layer.
            style_loss = 0.0
            for key, value in generated_features.items():
                
                # Broadcast the single style Gram matrix to the current batch size.
                style_target = style_gram[key].expand(curr_batch_size, -1, -1)
                
                s_loss = MSELoss(gram(value), style_target)
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT

            
            batch_style_loss_sum += style_loss.item()

            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            total_loss.backward()
            optimizer.step()

            # Log/checkpoint on iterations 1, 1+SAVE_MODEL_EVERY, ... and on the very last iteration.
            if (((batch_count-1)%SAVE_MODEL_EVERY == 0) or (batch_count==NUM_EPOCHS*len(train_loader))):
                print("========Iteration {}/{}========".format(batch_count, NUM_EPOCHS*len(train_loader)))
                # Printed values are means over all iterations so far, not over the last interval.
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum/batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum/batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum/batch_count))
                print("Time elapsed:\t{} seconds".format(time.time()-start_time))

                # Save Model
                checkpoint_path = os.path.join(
                    SAVE_MODEL_PATH, f"checkpoint_{batch_count-1}.pth"
                )
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = ttoi(sample_tensor.clone().detach())
                sample_image_path = os.path.join(
                    SAVE_IMAGE_PATH, f"sample0_{batch_count-1}.png"
                )
                saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(sample_image_path))

                content_loss_history.append(batch_content_loss_sum/batch_count)
                style_loss_history.append(batch_style_loss_sum/batch_count)
                total_loss_history.append(batch_total_loss_sum/batch_count)

            batch_count+=1

    stop_time = time.time()
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time-start_time))
    print("========Content Loss========")
    print(content_loss_history) 
    print("========Style Loss========")
    print(style_loss_history) 
    print("========Total Loss========")
    print(total_loss_history) 

    # Final weights are saved from the CPU copy of the network in eval mode.
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = os.path.join(SAVE_MODEL_PATH, "transformer_weight.pth")
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    if (PLOT_LOSS):
        plot_loss_hist(content_loss_history, style_loss_history, total_loss_history)

In [None]:
 
def display_image(img, title="Stylized Image"):
    """
    Show an image (NumPy array [H, W, C]) inline in the notebook.

    Values are divided by 255 when the array's maximum exceeds 1.0, and the
    first three channels are reordered with index [2, 1, 0] before plotting.
    """
    exceeds_unit_range = img.max() > 1.0
    scaled = img / 255.0 if exceeds_unit_range else img

    channel_order = [2, 1, 0]
    reordered = scaled[:, :, channel_order]

    plt.figure(figsize=(8, 8))
    plt.imshow(reordered)
    plt.title(title)
    plt.axis('off')
    plt.show()

In [None]:
# Run the full training loop; checkpoints, sample images and the final weights
# are written under SAVE_MODEL_PATH / SAVE_IMAGE_PATH.
train()

In [None]:
def stylize_image(content_image_path, output_image_path=None, display_only=True, model_path=None):
    """
    Stylize one input image with a trained TransformerNetModern and show it
    next to the original.

    Args:
        content_image_path: Path of the input image.
        output_image_path: Output file name (optional); it is joined onto
            SAVE_IMAGE_PATH when the result is saved.
        display_only: If True, only display the result and do not write a file.
        model_path: Optional path to the network weights. Defaults to the
            Kaggle input path that used to be hard-coded.

    Returns:
        The stylized image as a uint8 NumPy array [H, W, C] (handled as BGR by
        the display and save code below), or None if the weights or the image
        could not be loaded.
    """

    device = ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = TransformerNetModern().to(device)

    FINAL_MODEL_PATH = (
        model_path if model_path is not None
        else "/kaggle/input/transform-eight/pytorch/default/1/rabbit_weight.pth"
    )

    print("\n" + "="*50)
    print(f"Loading model weights from:\n{FINAL_MODEL_PATH}")

    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load weight files you trust.
        state_dict = torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=False)
        model.load_state_dict(state_dict, strict=True)
        print("Model weights loaded successfully!")

    except FileNotFoundError:
        print("ERROR: Model weights not found!")
        print(f"Looking for: {FINAL_MODEL_PATH}")
        return None
    except Exception as e:
        print(f"ERROR loading model: {e}")
        return None

    model.eval()
    print("Model set to evaluation mode")

    print(f"\nLoading content image from:\n{content_image_path}")

    try:
        content_image = load_image(content_image_path)
    except Exception as e:
        print(f" ERROR loading image: {e}")
        return None
    # cv2.imread reports an unreadable file by returning None instead of raising.
    if content_image is None:
        print(f" ERROR loading image: could not read {content_image_path}")
        return None
    print(f"Original image shape: {content_image.shape}")

    # Keep the channel order produced by load_image (BGR). train() reverses the
    # channels of the batches it loads before feeding the network, i.e. the
    # network is trained on BGR input; reversing the cv2 image here again would
    # hand it RGB and mismatch training.
    content_tensor = itot(content_image).to(device)

    print(f"Input tensor shape: {content_tensor.shape}")

    print("\n" + "="*50)
    print("Running style transfer...")

    with torch.no_grad():
        # CUDA kernels run asynchronously; synchronise so the timing covers the
        # actual computation rather than just the kernel launches.
        if device == "cuda":
            torch.cuda.synchronize()
        start_time = time.time()
        stylized_tensor = model(content_tensor)
        if device == "cuda":
            torch.cuda.synchronize()
        end_time = time.time()

    print(f"Style transfer completed in {end_time - start_time:.4f} seconds")
    print(f"Output tensor shape: {stylized_tensor.shape}")

    stylized_image = ttoi(stylized_tensor.detach())

    # Clamp to the displayable range and convert to 8-bit.
    stylized_image = stylized_image.clip(0, 255).astype(np.uint8)

    print("\n" + "="*50)
    print("Displaying results...")

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    axes[0].imshow(cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB))
    axes[0].set_title('Original Content Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(cv2.cvtColor(stylized_image, cv2.COLOR_BGR2RGB))
    axes[1].set_title('Stylized Image', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure so batch runs do not accumulate open figures.
    plt.close(fig)

    if not display_only and output_image_path:
        os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)
        final_output_path = os.path.join(SAVE_IMAGE_PATH, output_image_path)
        saveimg(stylized_image, final_output_path)
        print(f"Stylized image saved to: {final_output_path}")

    print("="*50)

    return stylized_image


In [None]:
# Stylize several images in one go
def stylize_batch(image_paths):
    """Stylize each image in ``image_paths`` and show the successful results in a grid.

    Images for which ``stylize_image`` returns None are skipped. The grid has at
    most three columns; nothing is drawn when no image succeeded.
    """
    stylized = []
    for img_path in image_paths:
        print(f"\n{'='*60}")
        print(f"Processing: {img_path}")
        print('='*60)
        outcome = stylize_image(img_path, display_only=True)
        if outcome is None:
            continue
        stylized.append(outcome)

    # Nothing to show
    if not stylized:
        return

    n_images = len(stylized)
    cols = min(3, n_images)
    rows = -(-n_images // cols)  # ceiling division

    # squeeze=False always yields a 2-D array of axes, so one ravel() covers every grid size.
    fig, grid = plt.subplots(rows, cols, figsize=(cols*5, rows*5), squeeze=False)
    panels = grid.ravel()

    for idx, picture in enumerate(stylized):
        panel = panels[idx]
        panel.imshow(cv2.cvtColor(picture.astype(np.uint8), cv2.COLOR_BGR2RGB))
        panel.set_title(f'Stylized {idx+1}', fontsize=12)
        panel.axis('off')

    # Hide unused subplots
    for spare in panels[n_images:]:
        spare.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Paths of the test images to stylize; each one is passed through stylize_image
# by stylize_batch, which then shows all successful results in one grid.
image_list = [
     "/kaggle/input/test-images/baseball.jpg",
     "/kaggle/input/test-images/Anh-chan-dung-nam.jpg",
     "/kaggle/input/test-images/children.jpg",
    "/kaggle/input/test-images/cycling.jpg",
    "/kaggle/input/test-images/Lion.jpg",
    "/kaggle/input/test-images/the-gate.jpg",
    "/kaggle/input/test-images/hoover.jpg"
 ]
stylize_batch(image_list)
