In [2]:
# Fetch the Places205 resized test-set archive and unpack it into images/.
# The archive is downloaded straight to the Drive path that tar reads from, so the two
# commands stay in sync; -c resumes a partial download and is a no-op when the file is
# already complete, and -p keeps mkdir from failing when the cell is re-run.
!wget -c -P drive/MyDrive/ http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz
!mkdir -p images/
!tar -xzf drive/MyDrive/testSetPlaces205_resize.tar.gz -C 'images/'

In [3]:
import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm

from network import ColorizeNet
from utils import GrayscaleImageFolder

In [4]:
# Run on the first CUDA GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# Random augmentation handed to the dataset: a random crop resized to 224x224,
# followed by a random horizontal flip.
augmentations = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
]
transform = transforms.Compose(augmentations)

# Dataset rooted at images/. The training loop reads element 0 of each batch as the
# grayscale input and element 1 as the colour target — presumably that is what
# GrayscaleImageFolder yields; verify against utils.py.
train_set = GrayscaleImageFolder('images/', transform=transform)

# Shuffled mini-batches of 128 samples, loaded by 4 worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [6]:
# Network under training, moved to the selected device.
model = ColorizeNet().to(device)

# Mean-squared-error objective between predicted and target colour channels.
criterion = nn.MSELoss()

# Adam with an initial learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Cut the learning rate to one third once the monitored value has failed to improve
# for more than `patience` epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm against
# the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=1/3,
    patience=5,
    verbose=True,
)

In [7]:
# Resume from the most recent checkpoint on Drive, or start from epoch 0 if none exists.
#
# The training loop names each checkpoint after the learning rate in effect when it was
# written (checkpoint_<lr>.pth). Once the scheduler lowers the LR, the newest file is
# therefore no longer checkpoint_0.001.pth, so pick the most recently modified match
# instead of a fixed name.
checkpoint_paths = glob.glob('drive/MyDrive/checkpoint_*.pth')

if checkpoint_paths:
    latest_path = max(checkpoint_paths, key=os.path.getmtime)
    checkpoint = torch.load(latest_path, map_location=device)

    # Errors are deliberately NOT caught here: a corrupt file or a state-dict mismatch
    # should stop the notebook rather than silently restart training from scratch and
    # overwrite existing checkpoints.
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])

    # Only commit the resume point once every component has loaded successfully.
    start_epoch = checkpoint['next_epoch']
    print(f'Resumed from {latest_path}; next epoch is {start_epoch + 1}')
else:
    start_epoch = 0

In [8]:
# Train from `start_epoch` up to `num_epochs`, stepping the plateau scheduler and
# writing a resumable checkpoint to Drive at the end of every epoch.
num_epochs = 64
for epoch in range(start_epoch, num_epochs):
    # Running batch-size-weighted SUM of the batch losses for this epoch (it is never
    # divided by the dataset size). The same quantity is shown in the progress bar and
    # fed to the scheduler, so it stays comparable with scheduler state restored from
    # earlier checkpoints.
    train_loss = 0

    # The description is constant within an epoch, so set it once instead of per batch.
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1:2d}/{num_epochs}]')
    for batch in loop:
        in_gray, in_ab = batch[0].to(device), batch[1].to(device)

        out_ab = model(in_gray)
        loss = criterion(out_ab, in_ab)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * in_gray.size(0)
        loop.set_postfix(loss=train_loss)

    scheduler.step(train_loss)

    # Read the learning rate from the optimizer (public API) rather than the
    # scheduler's private `_last_lr`; after `scheduler.step` both hold the same value.
    current_lr = optimizer.param_groups[0]['lr']

    checkpoint = {
        'next_epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    torch.save(checkpoint, f'drive/MyDrive/checkpoint_{current_lr}.pth')

Epoch [39/64]: 100%|██████████| 321/321 [12:00<00:00,  2.24s/it, loss=97.4]
Epoch [40/64]: 100%|██████████| 321/321 [11:30<00:00,  2.15s/it, loss=98]
Epoch [41/64]: 100%|██████████| 321/321 [11:18<00:00,  2.12s/it, loss=97.4]
Epoch [42/64]: 100%|██████████| 321/321 [11:28<00:00,  2.15s/it, loss=103]
Epoch [43/64]: 100%|██████████| 321/321 [11:23<00:00,  2.13s/it, loss=98.4]
Epoch [44/64]: 100%|██████████| 321/321 [11:15<00:00,  2.10s/it, loss=97.3]
Epoch [45/64]: 100%|██████████| 321/321 [11:29<00:00,  2.15s/it, loss=97.6]
Epoch [46/64]: 100%|██████████| 321/321 [11:25<00:00,  2.14s/it, loss=96.5]
Epoch [47/64]: 100%|██████████| 321/321 [11:40<00:00,  2.18s/it, loss=96.7]
Epoch [48/64]: 100%|██████████| 321/321 [11:44<00:00,  2.19s/it, loss=96.3]
Epoch [49/64]: 100%|██████████| 321/321 [11:13<00:00,  2.10s/it, loss=96.5]
Epoch [50/64]: 100%|██████████| 321/321 [11:09<00:00,  2.09s/it, loss=96.4]
Epoch [51/64]: 100%|██████████| 321/321 [11:09<00:00,  2.09s/it, loss=95.7]
Epoch [52/64]: 

In [None]:
# Export only the model weights (no optimizer or scheduler state).
# torch.save does not create missing parent directories, so make sure ./models exists.
os.makedirs('./models', exist_ok=True)
torch.save(model.state_dict(), './models/model.pth')