In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from tqdm import tqdm
from network import ColorizeNet
from utils import GrayscaleImageFolder

In [None]:
# Use the first CUDA device when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Random augmentation applied to each training image: a random crop resized
# to 224, followed by a random horizontal flip.
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
])

# Dataset rooted at 'images/'; the training loop reads batch[0] as the network
# input and batch[1] as the regression target.
train_set = GrayscaleImageFolder('images/', transform=transform)

# Shuffled mini-batches of 64, loaded by a single worker process.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=64,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Build the network and move its parameters to the selected device.
model = ColorizeNet().to(device)

# Mean-squared error is the training objective.
criterion = nn.MSELoss()

# Adam over every model parameter, starting from a learning rate of 0.01.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# Cut the learning rate to one third once the monitored value has stopped
# improving for longer than `patience` scheduler steps; verbose=True asks the
# scheduler to report each reduction.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    factor=1/3,
    patience=10,
    verbose=True,
)

In [None]:
# Train for a fixed number of epochs, then write the learned weights to disk.
num_epochs = 64
for epoch in range(num_epochs):
    # `train_loss` is the sample-weighted sum of batch losses for this epoch;
    # `num_seen` counts the samples behind it so a per-sample mean can be
    # reported even when the final batch is smaller than the others.
    train_loss = 0.0
    num_seen = 0

    loop = tqdm(train_loader)
    # The epoch label does not change within an epoch, so set it once instead
    # of on every batch.
    loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')

    for batch in loop:
        # batch[0] is fed to the model; batch[1] is the target it is compared to.
        in_gray, in_ab = batch[0].to(device), batch[1].to(device)
        out_ab = model(in_gray)
        loss = criterion(out_ab, in_ab)

        # Standard optimisation step: clear stale gradients, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the running total is a
        # sum over samples rather than over batches.
        n_batch = in_gray.size(0)
        train_loss += loss.item() * n_batch
        num_seen += n_batch

        # Show the running per-sample mean rather than the ever-growing sum,
        # so the number on the progress bar is comparable across epochs.
        loop.set_postfix(loss=train_loss / num_seen)

    # Step the plateau scheduler on the epoch's mean loss. With a fixed
    # per-epoch sample count this is a constant rescaling of the summed loss,
    # which the scheduler's default relative threshold should be insensitive to.
    # max(..., 1) keeps an empty loader from dividing by zero.
    epoch_loss = train_loss / max(num_seen, 1)
    scheduler.step(epoch_loss)

# Persist only the parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), 'colorizenet.pth')