In [2]:
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
# Dataset that reads the paired training images (input image + label image).
class CustomImageDataset(Dataset):
    """Paired image dataset: item ``i`` is ``(input image, label image)``.

    Input and label images are matched by file name: every image file found
    in ``img_in_dir`` is looked up under the same name in ``img_label_dir``.

    Args:
        img_in_dir: Directory containing the input images.
        img_label_dir: Directory containing the label images.
        transform: Optional callable applied to the input image and then,
            separately, to the label image. Because it is called twice, a
            transform with random behaviour would be sampled independently
            for the two images.
    """

    # Extensions treated as images. Everything else in the folder
    # (e.g. ``.DS_Store``, ``Thumbs.db``, text notes) is ignored instead of
    # crashing ``Image.open`` in the middle of an epoch.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                      '.tif', '.tiff', '.webp')

    def __init__(self, img_in_dir, img_label_dir, transform=None):
        self.img_in_dir = img_in_dir
        self.img_label_dir = img_label_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and machines.
        self.images = sorted(
            name for name in os.listdir(img_in_dir)
            if not name.startswith('.')
            and name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(img_in_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``img_in_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair as RGB and apply ``transform`` if set.

        Returns:
            Tuple ``(image_in, image_label)``.
        """
        name = self.images[idx]
        img_in_path = os.path.join(self.img_in_dir, name)
        img_label_path = os.path.join(self.img_label_dir, name)

        # Image.open is lazy and keeps the file open; convert() produces an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(img_in_path) as raw_in:
            image_in = raw_in.convert("RGB")
        with Image.open(img_label_path) as raw_label:
            image_label = raw_label.convert("RGB")

        if self.transform:
            image_in = self.transform(image_in)
            image_label = self.transform(image_label)
        return image_in, image_label
        
# Paired training folders: network inputs and their target (label) images.
img_input = './dataset/3D Rendered Cartoon Style/train/images'
img_label = './dataset/3D Rendered Cartoon Style/train/labels'

# Preprocessing shared by inputs and labels: fixed 224x224 resize, then
# conversion to a tensor.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])

trainset = CustomImageDataset(img_input, img_label, transform)

# Build the training DataLoader (shuffled mini-batches of 16, loaded in the
# main process).
trainloader = DataLoader(dataset=trainset, batch_size=16, num_workers=0, shuffle=True)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Style-conditioned U-Net that maps an RGB image to an RGB image.

    The requested style is one-hot encoded, broadcast over the spatial
    dimensions and concatenated to the input as ``num_styles`` extra
    channels. The network has three encoder stages (64/128/256 channels,
    each followed by a 2x2 max-pool), a 512-channel bottleneck and three
    decoder stages that upsample with stride-2 transposed convolutions and
    concatenate the matching encoder output (skip connection). The last
    stage ends in a 3-channel convolution followed by a sigmoid, so outputs
    lie in [0, 1].

    The layer layout (attribute names and positions inside each
    ``nn.Sequential``) is part of the checkpoint format: changing it would
    change the ``state_dict`` keys.

    Args:
        num_styles: Number of distinct style ids (size of the one-hot code).
    """

    def __init__(self, num_styles=3):
        super().__init__()
        self.num_styles = num_styles

        # === Encoder ===
        # First stage sees 3 RGB channels + the one-hot style channels.
        self.encoder1 = nn.Sequential(*self._double_conv(3 + num_styles, 64))
        self.pool1 = nn.MaxPool2d(2, 2)

        self.encoder2 = nn.Sequential(*self._double_conv(64, 128))
        self.pool2 = nn.MaxPool2d(2, 2)

        self.encoder3 = nn.Sequential(*self._double_conv(128, 256))
        self.pool3 = nn.MaxPool2d(2, 2)

        self.middle = nn.Sequential(*self._double_conv(256, 512))

        # === Decoder ===
        # Each decoder stage receives the upsampled features concatenated
        # with the skip connection, hence twice the upconv output channels.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = nn.Sequential(*self._double_conv(512, 256))

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(*self._double_conv(256, 128))

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(
            *self._double_conv(128, 64),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),  # back to RGB
            nn.Sigmoid(),
        )

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return the layers of a (3x3 conv -> ReLU) x 2 block as a list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after upsampling a feature map can
        be one pixel smaller than its skip connection. When the sizes
        already agree, ``up`` is returned unchanged.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(up, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x, style_id):
        """Run the network.

        Args:
            x: Image batch of shape ``[B, 3, H, W]``.
            style_id: Integer tensor of style indices in
                ``[0, num_styles)`` with shape ``[B]`` or ``[B, 1]``. A
                single-element tensor is broadcast to the whole batch.

        Returns:
            Tensor of shape ``[B, 3, H, W]`` with values in [0, 1].

        Raises:
            ValueError: If ``style_id`` holds neither 1 nor ``B`` elements.
        """
        batch, _, height, width = x.shape

        # Flatten to [B]; F.one_hot only accepts int64 indices, and the
        # style map must live on the same device as the image.
        style_id = style_id.reshape(-1).to(device=x.device, dtype=torch.long)
        if style_id.numel() == 1 and batch != 1:
            style_id = style_id.expand(batch)
        if style_id.numel() != batch:
            raise ValueError(
                f"style_id has {style_id.numel()} elements but the batch size is {batch}"
            )

        # One-hot code [B, num_styles] viewed as [B, num_styles, H, W].
        # expand() creates a view; torch.cat below makes the only copy.
        style_onehot = F.one_hot(style_id, num_classes=self.num_styles).float()
        style_map = style_onehot[:, :, None, None].expand(batch, self.num_styles, height, width)

        # Concatenate to input: [B, 3 + num_styles, H, W]
        x = torch.cat([x, style_map], dim=1)

        # Standard UNet forward
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        mid = self.middle(self.pool3(enc3))

        up3 = self._match_size(self.upconv3(mid), enc3)
        dec3 = self.decoder3(torch.cat([up3, enc3], dim=1))
        up2 = self._match_size(self.upconv2(dec3), enc2)
        dec2 = self.decoder2(torch.cat([up2, enc2], dim=1))
        up1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([up1, enc1], dim=1))
        return dec1


In [None]:
import torch.optim as optim

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# (Each backend is checked with its own availability call; requesting
# 'cuda' because MPS is available fails on Apple hardware.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

model = UNet().to(device)
criterion = nn.MSELoss()  # pixel-wise regression against the label image
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 50

# Style index every training sample is conditioned on.
# NOTE(review): presumably the index of the style folder loaded into
# `trainloader` — confirm the id <-> style mapping.
TRAIN_STYLE_ID = 2

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.float32)
        # One style id per sample; sized from the batch because the last
        # batch of an epoch may be smaller than batch_size.
        style_id = torch.full((images.size(0),), TRAIN_STYLE_ID, dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = model(images, style_id)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses; guard against an empty loader.
    avg_loss = running_loss / max(len(trainloader), 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.7f}')

print('Finished Training')
torch.save(model.state_dict(), 'styled_unet_model.pth')
print('Model saved as styled_unet_model.pth')

mps
Epoch [1/50], Loss: 0.0297306
Epoch [2/50], Loss: 0.0197056
Epoch [3/50], Loss: 0.0186904
Epoch [4/50], Loss: 0.0184432
Epoch [5/50], Loss: 0.0184920
Epoch [6/50], Loss: 0.0184308
Epoch [7/50], Loss: 0.0181614
Epoch [8/50], Loss: 0.0181555
Epoch [9/50], Loss: 0.0174354
Epoch [10/50], Loss: 0.0156544
Epoch [11/50], Loss: 0.0151087
Epoch [12/50], Loss: 0.0147580
Epoch [13/50], Loss: 0.0142678
Epoch [14/50], Loss: 0.0139481
Epoch [15/50], Loss: 0.0137170
Epoch [16/50], Loss: 0.0134296
Epoch [17/50], Loss: 0.0131833
Epoch [18/50], Loss: 0.0126863
Epoch [19/50], Loss: 0.0126434
Epoch [20/50], Loss: 0.0122865
Epoch [21/50], Loss: 0.0119951
Epoch [22/50], Loss: 0.0116827
Epoch [23/50], Loss: 0.0116268
Epoch [24/50], Loss: 0.0113672
Epoch [25/50], Loss: 0.0111814
Epoch [26/50], Loss: 0.0108342
Epoch [27/50], Loss: 0.0107045
Epoch [28/50], Loss: 0.0104727
Epoch [29/50], Loss: 0.0102284
Epoch [30/50], Loss: 0.0099634
Epoch [31/50], Loss: 0.0096565
Epoch [32/50], Loss: 0.0093060
Epoch [33/50]

In [18]:
import random
from torchvision.utils import save_image

# Inference only: switch the model to evaluation mode.
model.eval()

output_folder = './test'
input_folder = './dataset/Comic Style/train/input'
num_samples = 50  # upper bound on how many random images are stylised
image_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                    '.tif', '.tiff', '.webp')

# save_image does not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# Collect image files only (skips hidden files such as .DS_Store and
# sub-folders); sorted so the candidate list is stable across runs.
image_paths = sorted(
    os.path.join(input_folder, f)
    for f in os.listdir(input_folder)
    if not f.startswith('.')
    and f.lower().endswith(image_extensions)
    and os.path.isfile(os.path.join(input_folder, f))
)

# random.sample raises ValueError when asked for more items than exist.
selected_paths = random.sample(image_paths, min(num_samples, len(image_paths)))

# Style id for a batch of exactly one image. Defined here instead of reusing
# the `style_id` left over from the training loop, whose length is the size
# of the last training batch and usually does not match a batch of 1.
inference_style_id = torch.full((1,), 2, dtype=torch.long, device=device)

for i, path in enumerate(selected_paths):
    with Image.open(path) as raw_img:
        img = raw_img.convert("RGB")
    width, height = img.size  # PIL reports (width, height)
    input_tensor = transform(img).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = model(input_tensor, inference_style_id)
        # Scale the prediction back to the source resolution; interpolate
        # expects (height, width).
        output_resized = F.interpolate(output, size=(height, width), mode='bilinear', align_corners=False)

    save_image(output_resized, os.path.join(output_folder, f"styled_3d_{i+1}.jpg"))
