In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

100%|████████████████████████████████████████████████████████████████████████████████| 170M/170M [08:31<00:00, 333kB/s]


In [9]:
def rgb_to_grayscale(tensor):
    gray = 0.2989 * tensor[:, 0:1, :, :] + 0.5870 * tensor[:, 1:2, :, :] + 0.1140 * tensor[:, 2:3, :, :]
    return gray

In [10]:
def generate_random_hints(gray, color, num_points=10):
    B, _, H, W = gray.size()
    hint_rgb = torch.zeros_like(color)
    hint_mask = torch.zeros((B, 1, H, W), dtype=torch.float32)

    for b in range(B):
        for _ in range(num_points):
            y = torch.randint(0, H, (1,))
            x = torch.randint(0, W, (1,))
            hint_rgb[b, :, y, x] = color[b, :, y, x]
            hint_mask[b, 0, y, x] = 1.0

    return hint_rgb, hint_mask

In [11]:
class HintColorizationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 64, 3, padding=1),  # 1(gray) + 3(hint_rgb) + 1(hint_mask)
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, gray, hint_rgb, hint_mask):
        x = torch.cat([gray, hint_rgb, hint_mask], dim=1)
        features = self.encoder(x)
        output = self.decoder(features)
        return output

In [12]:
model = HintColorizationNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [13]:
for epoch in range(1):
    for i, (color, _) in enumerate(trainloader):
        color = color.to(device)
        gray = rgb_to_grayscale(color).to(device)

        hint_rgb, hint_mask = generate_random_hints(gray, color)
        hint_rgb = hint_rgb.to(device)
        hint_mask = hint_mask.to(device)

        output = model(gray, hint_rgb, hint_mask)
        loss = criterion(output, color)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}], Step [{i}], Loss: {loss.item():.4f}")

Epoch [1], Step [0], Loss: 0.2748
Epoch [1], Step [100], Loss: 0.0082
Epoch [1], Step [200], Loss: 0.0050
Epoch [1], Step [300], Loss: 0.0057
Epoch [1], Step [400], Loss: 0.0068
Epoch [1], Step [500], Loss: 0.0045
Epoch [1], Step [600], Loss: 0.0048
Epoch [1], Step [700], Loss: 0.0042
Epoch [1], Step [800], Loss: 0.0041
Epoch [1], Step [900], Loss: 0.0045
Epoch [1], Step [1000], Loss: 0.0027
Epoch [1], Step [1100], Loss: 0.0033
Epoch [1], Step [1200], Loss: 0.0047
Epoch [1], Step [1300], Loss: 0.0040
Epoch [1], Step [1400], Loss: 0.0069
Epoch [1], Step [1500], Loss: 0.0034


In [14]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score
import torch
import numpy as np

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for gray_img, true_color in test_loader: 
        gray_img = gray_img.to(device)
        true_color = true_color.to(device)

        hint_rgb, hint_mask = generate_random_hints(gray_img, true_color)

        output = model(gray_img, hint_rgb.to(device), hint_mask.to(device))

        pred_label = torch.argmax(output, dim=1).flatten().cpu().numpy()
        true_label = torch.argmax(true_color, dim=1).flatten().cpu().numpy()

        all_preds.extend(pred_label)
        all_labels.extend(true_label)


conf_matrix = confusion_matrix(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='macro')
recall = recall_score(all_labels, all_preds, average='macro')

print("Confusion Matrix:\n", conf_matrix)
print("Precision (macro):", precision)
print("Recall (macro):", recall)


ValueError: Classification metrics can't handle a mix of multiclass and unknown targets

In [None]:

torch.save(model.state_dict(), 'colorization_model.pth')

torch.save(model, 'colorization_model_full.pth')

In [None]:
model.eval()
with torch.no_grad():
    sample_color = next(iter(trainloader))[0][:1].to(device)
    sample_gray = rgb_to_grayscale(sample_color)
    hint_rgb, hint_mask = generate_random_hints(sample_gray, sample_color)
    hint_rgb = hint_rgb.to(device)
    hint_mask = hint_mask.to(device)

    pred = model(sample_gray.to(device), hint_rgb, hint_mask)


In [None]:
import matplotlib.pyplot as plt  

def show_image(tensor, title=""):
    img = tensor.cpu().squeeze()
    if img.dim() == 2:
        # Grayscale image
        plt.imshow(img.numpy(), cmap="gray")
    elif img.dim() == 3:
        # Color image [C, H, W]
        img = img.permute(1, 2, 0).numpy()
        plt.imshow(img)
    else:
        raise ValueError("Unsupported tensor shape for visualization")
    
    plt.title(title)
    plt.axis("off")


In [None]:
import torch

sample_gray = torch.rand(1, 64, 64)        
hint_rgb = torch.rand(3, 64, 64)          
pred = torch.rand(3, 64, 64)                
sample_color = torch.rand(3, 64, 64)       


In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 4, 1); show_image(sample_gray, "Gray")
plt.subplot(1, 4, 2); show_image(hint_rgb, "Hint")
plt.subplot(1, 4, 3); show_image(pred, "Predicted")
plt.subplot(1, 4, 4); show_image(sample_color, "Original")
plt.show()


In [None]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import numpy as np

def colorize_image():
    # Load image
    file_path = filedialog.askopenfilename()
    if not file_path:
        return
    
    # Load and preprocess image
    img = Image.open(file_path).convert('RGB').resize((32, 32))
    img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    gray = rgb_to_grayscale(img_tensor)
    hint_rgb, hint_mask = generate_random_hints(gray, img_tensor)
    
    # Colorize
    model.eval()
    with torch.no_grad():
        output = model(gray.to(device), hint_rgb.to(device), hint_mask.to(device))
    
    # Convert to PIL image
    output = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    output = (output * 255).astype(np.uint8)
    output_img = Image.fromarray(output)
    
    # Display
    output_img = output_img.resize((256, 256))
    output_tk = ImageTk.PhotoImage(output_img)
    output_label.config(image=output_tk)
    output_label.image = output_tk

# Create GUI
root = tk.Tk()
root.title('Image Colorization')

load_button = tk.Button(root, text="Load Image", command=colorize_image)
load_button.pack()

output_label = tk.Label(root)
output_label.pack()

root.mainloop()