In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
steubk_wikiart_path = kagglehub.dataset_download('steubk/wikiart')

print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/steubk/wikiart?dataset_version_number=1...


100%|██████████| 31.4G/31.4G [05:48<00:00, 96.7MB/s]

Extracting files...





Data source import complete.


In [None]:
# Install required libraries in KaggleHub
!pip install torch torchvision pandas matplotlib tqdm

# Optional: Mount Google Drive for persistent storage (if linked with Colab)
# Uncomment the following lines if you want to save artifacts to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')

# Import necessary libraries
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import glob

# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Create directory for saving artifacts in KaggleHub
artifacts_dir = '/kaggle/working/artifacts'
os.makedirs(artifacts_dir, exist_ok=True)

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

# **PRETRAITEMENT**

In [None]:
# Configuration
DATA_DIR = steubk_wikiart_path  
IMAGE_SIZE = 512
BATCH_SIZE = 8   
MAX_SAMPLES = 1000  

# normalize to [-1 , 1]
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Lambda(lambda x: x * 2 - 1)  # [-1, 1]
])

#customize DS
class LimitedWikiArtDataset(Dataset):
    def __init__(self, data_dir, max_samples=MAX_SAMPLES):
        self.data_dir = data_dir
        #ensure dirs exists
        if not os.path.exists(data_dir):
            raise ValueError(f'Dataset directory {data_dir} does not exist.')

        self.image_paths = []
        for ext in ['jpg', 'jpeg', 'png', 'JPG']:
            self.image_paths.extend(glob.glob(os.path.join(data_dir, f'**/*.{ext}'), recursive=True))

        
        self.image_paths = sorted(list(set(self.image_paths)))

        print(f'Found {len(self.image_paths)} images in {data_dir}')
        if len(self.image_paths) > 0:
            print(f'Example image paths: {self.image_paths[:3]}')

        if not self.image_paths:
            raise ValueError(f'No images found in {data_dir}. Ensure the dataset contains .jpg, .jpeg, or .png files.')

        self.image_paths = self.image_paths[:min(max_samples, len(self.image_paths))]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            img = Image.open(img_path).convert('RGB')
            img = self.transform(img)
            return {'image': img}
        except Exception as e:
            print(f'Error loading image {img_path}: {e}')
            return {'image': torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)}

#load sample DS
try:
    print('Loading dataset from', DATA_DIR)
    dataset = LimitedWikiArtDataset(data_dir=DATA_DIR, max_samples=MAX_SAMPLES)
    print(f'Dataset loaded successfully with {len(dataset)} images')
except Exception as e:
    print(f'Error loading dataset: {e}')
    raise

# create dataholder for DS
def collate_fn(examples):
    images = []
    for example in examples:
        img = example['image']
        # Skip dummy images
        if img.sum() == 0:
            continue
        images.append(img)
    if not images:
        raise ValueError('No valid images in batch')
    return torch.stack(images)

# Wrap dataset in a DataLoader for batching
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0, shuffle=True)

# Test loading a batch
#print('Testing DataLoader...')
#for i, batch in enumerate(tqdm(dataloader, desc='Testing DataLoader', total=10)):
#    print(f'Batch {i+1} shape: {batch.shape}')  # Expected: (batch_size, 3, 128, 128)
#    batch = batch.to(device)  # Move to GPU
#    print(f'Batch moved to {device}')
#    if i >= 9:  # Stop after 10 batches
#        break

# Check GPU memory usage
!nvidia-smi

Loading dataset from /root/.cache/kagglehub/datasets/steubk/wikiart/versions/1
Found 81444 images in /root/.cache/kagglehub/datasets/steubk/wikiart/versions/1
Example image paths: ['/root/.cache/kagglehub/datasets/steubk/wikiart/versions/1/Abstract_Expressionism/aaron-siskind_acolman-1-1955.jpg', '/root/.cache/kagglehub/datasets/steubk/wikiart/versions/1/Abstract_Expressionism/aaron-siskind_chicago-1951.jpg', '/root/.cache/kagglehub/datasets/steubk/wikiart/versions/1/Abstract_Expressionism/aaron-siskind_chicago-6-1961.jpg']
Dataset loaded successfully with 1000 images
Fri May  2 17:48:30 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf       

In [None]:
#load and preprocess style image
#note: uload georges-seurat_sunday-afternoon-on-the-island-of-la-grande-jatte.jpg to /kaggle/input/georges-seurat_sunday-afternoon-on-the-island-of-la-grande-jatte/ or /kaggle/working/
style_img_path = './style-image.jpg'
if not os.path.exists(style_img_path):
    style_img_path = '/kaggle/working/air-terjun.jpg'
    if not os.path.exists(style_img_path):
        print('Style image not found. Please upload georges-seurat_sunday-afternoon-on-the-island-of-la-grande-jatte.jpg to KaggleHub.')
        try:
            from google.colab import files
            uploaded = files.upload()
            style_img_path = list(uploaded.keys())[0]  # Use uploaded file
        except ImportError:
            print('Colab file upload not available in KaggleHub. Please upload georges-seurat_sunday-afternoon-on-the-island-of-la-grande-jatte.jpg manually to /kaggle/working/')
            raise FileNotFoundError('Style image not found')

try:
    style_img = Image.open(style_img_path).convert('RGB')
    style_tensor = transform(style_img).unsqueeze(0).to(device)
    print(f'Style image tensor min: {style_tensor.min()}, max: {style_tensor.max()}')  # Should be [-1, 1]
except Exception as e:
    print(f'Error loading style image: {e}')
    raise

Style image tensor min: -0.9529411792755127, max: 0.8980392217636108


In [None]:
# transformer-NETWORK
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.layer = nn.Sequential(
            nn.ReflectionPad2d(reflection_padding),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layer(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            ConvLayer(channels, channels, 3, 1),
            ConvLayer(channels, channels, 3, 1)
        )

    def forward(self, x):
        return x + self.block(x)

class TransformerNet(nn.Module):
    def __init__(self):
        super(TransformerNet, self).__init__()
        self.encoder = nn.Sequential(
            ConvLayer(3, 32, 9, 1),
            ConvLayer(32, 64, 3, 2),
            ConvLayer(64, 128, 3, 2),
        )
        self.residuals = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvLayer(128, 64, 3, 1),
            nn.Upsample(scale_factor=2),
            ConvLayer(64, 32, 3, 1),
            ConvLayer(32, 3, 9, 1),
            nn.Tanh(),  # [-1, 1] output
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.residuals(x)
        x = self.decoder(x)
        return x

# Instantiate and move to device
transformer = TransformerNet().to(device)

# **2 : Définir la perte (Loss)**

In [None]:
# VGG features for content and style loss
class VGGFeatures(nn.Module):
    def __init__(self):
        super(VGGFeatures, self).__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        self.relu1_2 = nn.Sequential(*vgg[:4])
        self.relu2_2 = nn.Sequential(*vgg[4:9])
        self.relu3_3 = nn.Sequential(*vgg[9:16])
        self.relu4_3 = nn.Sequential(*vgg[16:23])
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = {}
        out['relu1_2'] = self.relu1_2(x)
        out['relu2_2'] = self.relu2_2(out['relu1_2'])
        out['relu3_3'] = self.relu3_3(out['relu2_2'])
        out['relu4_3'] = self.relu4_3(out['relu3_3'])
        return out

# move VGG isntance todevice
vgg = VGGFeatures().to(device).eval()

#gram_matrix for style loss
def gram_matrix(tensor):
    b, c, h, w = tensor.size()
    features = tensor.view(b, c, h * w)
    G = torch.bmm(features, features.transpose(1, 2))
    return G / (c * h * w)

#compute content loss / style loss
def compute_losses(content_img, style_img, stylized_img, vgg):
    content_features = vgg(content_img)
    style_features = vgg(style_img)
    stylized_features = vgg(stylized_img)

    # content_loss (using relu4_3)
    content_loss = F.mse_loss(stylized_features['relu4_3'], content_features['relu4_3'])

    # style_loss (using multiple layers)
    style_loss = 0
    style_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']
    for layer in style_layers:
        stylized_gram = gram_matrix(stylized_features[layer])
        style_gram = gram_matrix(style_features[layer])
        style_loss += F.mse_loss(stylized_gram, style_gram)

    return content_loss, style_loss

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 76.9MB/s]


# **3 : Boucle d'entraînement**

In [None]:
# Checkpoint saving and loading
def save_checkpoint(transformer, optimizer, epoch, path=os.path.join(artifacts_dir, 'checkpoint.pth')):
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

def load_checkpoint(transformer, optimizer, path=os.path.join(artifacts_dir, 'checkpoint.pth')):
    if os.path.isfile(path):
        checkpoint = torch.load(path, map_location=device)
        transformer.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f'✅ Checkpoint loaded, resuming from epoch {start_epoch}')
        return start_epoch
    else:
        print('ℹ️ No checkpoint found, training from scratch')
        return 0

#traning looop
def train(transformer, dataloader, style_tensor, vgg, epochs=50, content_weight=1e5, style_weight=1e10):
    optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-3)
    start_epoch = load_checkpoint(transformer, optimizer)

    # Lists to store losses for plotting
    content_losses = []
    style_losses = []
    total_losses = []

    transformer.train()
    for epoch in range(start_epoch, epochs):
        epoch_content_loss = 0
        epoch_style_loss = 0
        epoch_total_loss = 0
        batch_count = 0

        total_batches = (MAX_SAMPLES + BATCH_SIZE - 1) // BATCH_SIZE

        for i, batch in enumerate(tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}', total=total_batches)):
            try:
                batch = batch.to(device)
                optimizer.zero_grad()

                stylized = transformer(batch)

                content_loss, style_loss = compute_losses(batch, style_tensor, stylized, vgg)
                total_loss = content_weight * content_loss + style_weight * style_loss
                total_loss.backward()
                optimizer.step()

                epoch_content_loss += content_loss.item()
                epoch_style_loss += style_loss.item()
                epoch_total_loss += total_loss.item()
                batch_count += 1

                # Save sample stylized image every 50 iterations
                if i % 50 == 0:
                    with torch.no_grad():
                        img = denormalize(stylized[0]).permute(1, 2, 0).cpu().numpy()
                        plt.imsave(os.path.join(artifacts_dir, f'stylized_epoch{epoch}_iter{i}.png'), img)

            except Exception as e:
                print(f'Error in training batch {i}: {e}')
                continue

        if batch_count > 0:
            content_losses.append(epoch_content_loss / batch_count)
            style_losses.append(epoch_style_loss / batch_count)
            total_losses.append(epoch_total_loss / batch_count)

            print(f'Epoch {epoch+1}, Content Loss: {content_losses[-1]:.4f}, Style Loss: {style_losses[-1]:.4f}, Total Loss: {total_losses[-1]:.4f}')

        save_checkpoint(transformer, optimizer, epoch)

    plt.figure(figsize=(10, 5))
    plt.plot(content_losses, label='Content Loss')
    plt.plot(style_losses, label='Style Loss')
    plt.plot(total_losses, label='Total Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    plt.savefig(os.path.join(artifacts_dir, 'loss_plot.png'))
    plt.close()

    return content_losses, style_losses, total_losses

def denormalize(img_tensor):
    img = (img_tensor + 1) / 2  # Inverse of x*2 - 1
    img = img.clamp(0, 1)
    return img

# Run training
content_losses, style_losses, total_losses = train(transformer, dataloader, style_tensor, vgg)

ℹ️ No checkpoint found, training from scratch


  style_loss += F.mse_loss(stylized_gram, style_gram)
  style_loss += F.mse_loss(stylized_gram, style_gram)
  style_loss += F.mse_loss(stylized_gram, style_gram)
  style_loss += F.mse_loss(stylized_gram, style_gram)
Epoch 1/50: 100%|██████████| 125/125 [04:36<00:00,  2.21s/it]


Epoch 1, Content Loss: 3.8468, Style Loss: 0.0000, Total Loss: 638019.8692


Epoch 2/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 2, Content Loss: 3.1908, Style Loss: 0.0000, Total Loss: 438264.0557


Epoch 3/50: 100%|██████████| 125/125 [04:39<00:00,  2.23s/it]


Epoch 3, Content Loss: 2.9130, Style Loss: 0.0000, Total Loss: 394361.3285


Epoch 4/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 4, Content Loss: 2.6980, Style Loss: 0.0000, Total Loss: 363479.7100


Epoch 5/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 5, Content Loss: 2.6991, Style Loss: 0.0000, Total Loss: 368758.4228


Epoch 6/50: 100%|██████████| 125/125 [04:39<00:00,  2.24s/it]


Epoch 6, Content Loss: 2.5006, Style Loss: 0.0000, Total Loss: 340456.6252


Epoch 7/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 7, Content Loss: 2.3831, Style Loss: 0.0000, Total Loss: 325098.7413


Epoch 8/50: 100%|██████████| 125/125 [04:37<00:00,  2.22s/it]


Epoch 8, Content Loss: 2.2803, Style Loss: 0.0000, Total Loss: 313434.5109


Epoch 9/50: 100%|██████████| 125/125 [04:33<00:00,  2.19s/it]


Epoch 9, Content Loss: 2.2219, Style Loss: 0.0000, Total Loss: 306339.9771


Epoch 10/50: 100%|██████████| 125/125 [04:33<00:00,  2.19s/it]


Epoch 10, Content Loss: 2.1517, Style Loss: 0.0000, Total Loss: 298530.9477


Epoch 11/50: 100%|██████████| 125/125 [04:33<00:00,  2.19s/it]


Epoch 11, Content Loss: 2.0662, Style Loss: 0.0000, Total Loss: 288508.1506


Epoch 12/50: 100%|██████████| 125/125 [04:39<00:00,  2.24s/it]


Epoch 12, Content Loss: 2.0753, Style Loss: 0.0000, Total Loss: 290392.4671


Epoch 13/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 13, Content Loss: 2.0634, Style Loss: 0.0000, Total Loss: 288754.6302


Epoch 14/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 14, Content Loss: 1.9622, Style Loss: 0.0000, Total Loss: 277144.2679


Epoch 15/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 15, Content Loss: 1.9365, Style Loss: 0.0000, Total Loss: 274128.1550


Epoch 16/50: 100%|██████████| 125/125 [04:39<00:00,  2.23s/it]


Epoch 16, Content Loss: 1.9132, Style Loss: 0.0000, Total Loss: 271911.0236


Epoch 17/50: 100%|██████████| 125/125 [04:38<00:00,  2.23s/it]


Epoch 17, Content Loss: 1.8548, Style Loss: 0.0000, Total Loss: 264752.5995


Epoch 18/50: 100%|██████████| 125/125 [04:39<00:00,  2.23s/it]


Epoch 18, Content Loss: 1.9008, Style Loss: 0.0000, Total Loss: 270446.2793


Epoch 19/50:  46%|████▋     | 58/125 [02:09<02:30,  2.25s/it]

In [None]:
# Visualize sample images from the dataset
dataiter = iter(dataloader)
try:
    images = next(dataiter)
    fig, axes = plt.subplots(1, min(BATCH_SIZE, 4), figsize=(10, 3))
    for idx in range(min(BATCH_SIZE, 4)):
        img = denormalize(images[idx]).permute(1, 2, 0).cpu().numpy()
        axes[idx].imshow(img)
        axes[idx].axis('off')
    plt.savefig(os.path.join(artifacts_dir, 'sample_images.png'))
    plt.close()

    # Visualize a stylized image
    transformer.eval()
    with torch.no_grad():
        stylized = transformer(images[:1].to(device))
        img = denormalize(stylized[0]).permute(1, 2, 0).cpu().numpy()
        plt.imshow(img)
        plt.axis('off')
        plt.savefig(os.path.join(artifacts_dir, 'final_stylized.png'))
        plt.close()
except Exception as e:
    print(f'Error visualizing images: {e}')

# Save loss data to CSV for report
loss_df = pd.DataFrame({
    'Epoch': range(1, len(content_losses) + 1),
    'Content Loss': content_losses,
    'Style Loss': style_losses,
    'Total Loss': total_losses
})
loss_df.to_csv(os.path.join(artifacts_dir, 'losses.csv'), index=False)

# Check GPU memory usage after training
!nvidia-smi

Fri May  2 16:37:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   77C    P0             34W /   70W |    2018MiB /  15360MiB |      9%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                