In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# Download and unzip (2.2GB)
# !wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz
# !mkdir images/
# !tar -xzf drive/MyDrive/testSetPlaces205_resize.tar.gz -C 'images/'

In [3]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm

from network import ColorizeNet
from utils import GrayscaleImageFolder

In [4]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Training-time augmentation: a random crop rescaled to 224x224, then a random
# horizontal flip.
transform = transforms.Compose(
    [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
)

# Dataset rooted at 'images/'. NOTE(review): the training loop reads batch[0] as
# the grayscale input and batch[1] as the ab target, so GrayscaleImageFolder
# presumably yields such pairs -- confirm against utils.py.
train_set = GrayscaleImageFolder('images/', transform=transform)

# Reshuffled every epoch; batches of 128 produced by 4 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [6]:
# Network under training, placed on the selected device.
model = ColorizeNet().to(device)

# Mean-squared error between the predicted and the reference output.
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.003)

# Alternative fixed-milestone schedule (disabled):
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[17, 47], gamma=1/3)

# Cut the learning rate to a third once the monitored value has stopped
# improving for more than `patience` epochs; verbose=True prints each reduction.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=1/3,
    patience=5,
    verbose=True,
)

In [7]:
# Resume from a saved checkpoint when one exists; otherwise start from scratch.
#
# NOTE: the training cell writes checkpoints to f'models/model_{lr}.pth', so this
# hard-coded path only matches a checkpoint saved while the learning rate was 0.001.
try:
    checkpoint = torch.load('models/model_0.001.pth', map_location=device)
except FileNotFoundError:
    # No checkpoint on disk yet: begin a fresh run.
    start_epoch = 0
else:
    # Only a missing file means "start over". A corrupt file, a missing key or a
    # state-dict mismatch must surface as an error instead of silently resetting
    # start_epoch to 0 on top of a partially restored model.
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['next_epoch']

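Training process (epoch numbers are 0-indexed, matching `epoch` in the training loop; the progress bars below display `epoch + 1`):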

| start_epoch | end_epoch | learning_rate |
| -: | -: | -: |
| 0 | 18 | 0.03 |
| 19 | 48 | 0.01 |
| 49 | 62 | 0.003 |
| 63 | 68 | 0.001 |

In [8]:
num_epochs = 69

# torch.save does not create missing directories; make sure the checkpoint
# folder exists up front so a finished epoch is never lost at save time.
os.makedirs('models', exist_ok=True)

for epoch in range(start_epoch, num_epochs):
    # Running sum of per-sample losses for this epoch (not averaged).
    train_loss = 0
    # The description is constant within an epoch, so set it once here instead
    # of on every batch; the bar is then labelled from the first refresh.
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1:2d}/{num_epochs}]')
    for batch in loop:
        in_gray, in_ab = batch[0].to(device), batch[1].to(device)
        out_ab = model(in_gray)
        loss = criterion(out_ab, in_ab)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # smaller final batch contributes proportionally.
        train_loss += loss.item()*in_gray.size(0)
        loop.set_postfix(loss=train_loss)

    # ReduceLROnPlateau monitors the epoch's summed training loss.
    scheduler.step(train_loss)

    # Read the current learning rate through the optimizer's public param
    # groups rather than the scheduler's private `_last_lr` attribute.
    current_lr = optimizer.param_groups[0]['lr']

    # Everything needed to resume: the next epoch index plus all state dicts.
    checkpoint = {
        'next_epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    # One file per learning-rate value; it is overwritten each epoch while the
    # learning rate stays the same.
    torch.save(checkpoint, f'models/model_{current_lr}.pth')

Epoch [56/69]: 100%|██████████| 321/321 [10:43<00:00,  2.01s/it, loss=99.6]
Epoch [57/69]: 100%|██████████| 321/321 [10:40<00:00,  1.99s/it, loss=99.1]
Epoch [58/69]: 100%|██████████| 321/321 [10:41<00:00,  2.00s/it, loss=99.3]
Epoch [59/69]: 100%|██████████| 321/321 [10:57<00:00,  2.05s/it, loss=99.3]
Epoch [60/69]: 100%|██████████| 321/321 [11:42<00:00,  2.19s/it, loss=99.3]
Epoch [61/69]: 100%|██████████| 321/321 [11:55<00:00,  2.23s/it, loss=99.4]
Epoch [62/69]: 100%|██████████| 321/321 [11:57<00:00,  2.23s/it, loss=99.1]
Epoch [63/69]: 100%|██████████| 321/321 [12:03<00:00,  2.25s/it, loss=99.1]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch    14: reducing learning rate of group 0 to 1.0000e-03.


Epoch [64/69]: 100%|██████████| 321/321 [11:54<00:00,  2.23s/it, loss=98.2]
Epoch [65/69]: 100%|██████████| 321/321 [11:21<00:00,  2.12s/it, loss=98.3]
Epoch [66/69]: 100%|██████████| 321/321 [11:11<00:00,  2.09s/it, loss=98.2]
Epoch [67/69]: 100%|██████████| 321/321 [11:17<00:00,  2.11s/it, loss=98.3]
Epoch [68/69]: 100%|██████████| 321/321 [11:31<00:00,  2.16s/it, loss=98.3]
Epoch [69/69]: 100%|██████████| 321/321 [11:15<00:00,  2.10s/it, loss=98.1]
