In [1]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
import tqdm

def get_computing_device():
    """Return the first CUDA device when CUDA is available, otherwise the CPU."""
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


device = get_computing_device()
print(f"Our main computing device is '{device}'")

Our main computing device is 'cuda:0'


In [65]:
def compute_loss(predictions, gt):
    """Mean binary cross-entropy between raw logits and 0/1 targets.

    Args:
        predictions: raw (pre-sigmoid) scores, same shape as ``gt``.
        gt: integer or boolean 0/1 labels; cast to float for the loss.

    Returns:
        A 0-dim tensor holding the mean loss over all elements.
    """
    targets = gt.float()
    return F.binary_cross_entropy_with_logits(predictions, targets, reduction='mean')

def eval_model(model, data_generator, attr_idx=11):
    """Return the classification accuracy of ``model`` on one binary attribute.

    Args:
        model: module producing one logit per sample. Its output is compared
            element-wise with ``y_batch[:, attr_idx:attr_idx + 1]``, so it is
            assumed to have shape (batch, 1) — TODO confirm for new heads.
        data_generator: iterable of ``(X_batch, y_batch)`` pairs (e.g. a
            DataLoader) where ``y_batch`` holds 0/1 labels, one column per
            attribute.
        attr_idx: label column to evaluate (default 11, the column the
            training loop uses).

    Returns:
        Fraction of correct predictions over all samples, or ``nan`` if the
        iterable yields no samples.

    Note: switches ``model`` to eval mode and leaves it there.
    """
    model.train(False)
    # Evaluate on whichever device the model's weights live on instead of
    # relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm.tqdm(data_generator):
            probs = torch.sigmoid(model(X_batch.to(run_device)))
            # 1 where prob > 0.5; torch.round sends exactly 0.5 to 0.
            y_pred = probs.round().cpu()
            matches = y_batch[:, attr_idx:attr_idx + 1].cpu() == y_pred
            # Accumulate counts rather than averaging per-batch means, so a
            # short final batch is not over-weighted.
            n_correct += int(matches.sum())
            n_seen += matches.numel()

    return n_correct / n_seen if n_seen else float('nan')

            
def train_model(model, optimizer, train_data_generator, attr_idx=11, max_batches=100):
    """Train ``model`` for one (by default truncated) pass and return the mean loss.

    Args:
        model: module producing one logit per sample for each image batch.
        optimizer: optimizer over ``model``'s parameters.
        train_data_generator: iterable of ``(X_batch, y_batch)`` pairs where
            ``y_batch`` holds 0/1 labels, one column per attribute.
        attr_idx: label column used as the binary target (default 11, the
            column the evaluation code scores).
        max_batches: stop after this many optimizer steps; ``None`` consumes
            the whole iterable. Defaults to 100 to keep the notebook's short
            debugging "epochs".

    Returns:
        Mean of the per-batch losses, or ``nan`` if no batch was processed.

    Raises:
        ValueError: if ``max_batches`` is not ``None`` and is less than 1.

    Note: switches ``model`` to train mode.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError(f"max_batches must be >= 1 or None, got {max_batches}")

    model.train(True)
    # Train on whichever device the model's weights live on instead of
    # relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss = []
    for X_batch, y_batch in tqdm.tqdm(train_data_generator):
        optimizer.zero_grad()
        predictions = model(X_batch.to(run_device))
        target = y_batch[:, attr_idx:attr_idx + 1].to(run_device)
        loss = compute_loss(predictions, target)

        # backward pass + parameter update
        loss.backward()
        optimizer.step()

        # metrics
        train_loss.append(loss.item())
        # Checked after the step so the loader is never asked for an extra batch.
        if max_batches is not None and len(train_loss) >= max_batches:
            break

    return np.mean(train_loss) if train_loss else float('nan')


def train_loop(model, optimizer, train_data_generator, val_data_generator, num_epochs):
    """Alternate training and validation rounds, printing a summary after each.

    Each round calls ``train_model`` once (which may stop before exhausting
    ``train_data_generator``), then ``eval_model`` on the validation data.

    Args:
        model: the network being trained.
        optimizer: optimizer over ``model``'s parameters.
        train_data_generator: iterable of training ``(X, y)`` batches.
        val_data_generator: iterable of validation ``(X, y)`` batches.
        num_epochs: number of train/validate rounds to run.
    """
    for epoch_no in range(1, num_epochs + 1):
        started = time.time()

        epoch_loss = train_model(model, optimizer, train_data_generator)
        epoch_acc = eval_model(model, val_data_generator)

        elapsed = time.time() - started
        print("Epoch {} of {} took {:.3f}s".format(epoch_no, num_epochs, elapsed))
        print("  training loss (in-iteration): \t{:.6f}".format(epoch_loss))
        print("  validation accuracy: \t\t\t{:.2f} %".format(epoch_acc * 100))

In [60]:
root = 'celeba'


class CropCelebA64:
    """Crop a CelebA face image to a fixed 148x148 box.

    The box is written relative to a 178 px wide, 218 px tall image: it trims
    15 px from the left and right, 40 px from the top and 30 px from the
    bottom. The dataset transforms resize the crop to 64 px afterwards (the
    "64" in the name presumably refers to that).
    """

    # (left, upper, right, lower), as expected by ``pic.crop``
    _BOX = (15, 40, 178 - 15, 218 - 30)

    def __call__(self, pic):
        return pic.crop(self._BOX)

    def __repr__(self):
        return f"{type(self).__name__}()"

def _celeba_transform(train):
    """Build the preprocessing pipeline for one CelebA split.

    Both splits crop, resize to 64 px and convert to a tensor; only the
    training split adds a random horizontal flip.
    """
    steps = [CropCelebA64(), torchvision.transforms.Resize(64)]
    if train:
        steps.append(torchvision.transforms.RandomHorizontalFlip())
    steps.append(torchvision.transforms.ToTensor())
    return torchvision.transforms.Compose(steps)


train_dataset = torchvision.datasets.CelebA(
    root=root, split='train', transform=_celeba_transform(train=True),
)
validation_dataset = torchvision.datasets.CelebA(
    root=root, split='valid', transform=_celeba_transform(train=False),
)

In [61]:
batch_size = 8

# Same batch size and worker count for both splits; only training is shuffled.
train_batch_gen = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
val_batch_gen = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [66]:
from unet3plus import UNet3Plus

class Decoder(torch.nn.Module):
    """Classification head: strided 3x3 conv, global max-pool, linear layer.

    Args:
        d: number of channels in the incoming feature map.
        num_out: number of logits produced per sample.

    Input is a (batch, d, H, W) feature map; output is (batch, num_out).
    """

    def __init__(self, d=64, num_out=1):
        super().__init__()
        self.conv = nn.Conv2d(d, 128, 3, 2, 1)
        # Global pooling to 1x1 whatever the spatial size. A fixed
        # MaxPool2d(32) only fits a 32x32 conv map (64x64 input): smaller maps
        # fail, and larger ones either drop pixels or break the Linear layer.
        # No parameters here, so existing state_dicts still load.
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Linear(128, num_out)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
    

class Net(torch.nn.Module):
    """A ``UNet3Plus`` encoder followed by a single-logit ``Decoder`` head.

    NOTE(review): ``n_channels``, ``bilinear``, ``feature_scale``,
    ``is_deconv`` and ``is_batchnorm`` are accepted but never used, because
    ``UNet3Plus()`` is built with its own defaults. Only ``d`` is used, as the
    Decoder's input channel count, so it must match the number of channels
    UNet3Plus emits — TODO confirm.
    """

    def __init__(self, n_channels=3, d=64, bilinear=True, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        super().__init__()
        self.encoder = UNet3Plus()
        self.decoder = Decoder(d=d, num_out=1)

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = Net()
model = model.to(device)

In [68]:
# Adam optimizer, then a 4-round train/validate run.
opt = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)
train_loop(model, opt, train_batch_gen, val_batch_gen, num_epochs=4)

  0%|          | 100/20347 [00:47<2:39:57,  2.11it/s]
  4%|▍         | 100/2484 [00:10<04:19,  9.17it/s]


Epoch 1 of 4 took 58.312s
  training loss (in-iteration): 	0.455782
  validation accuracy: 			75.50 %


  0%|          | 100/20347 [00:47<2:40:14,  2.11it/s]
  4%|▍         | 100/2484 [00:10<04:18,  9.22it/s]


Epoch 2 of 4 took 58.341s
  training loss (in-iteration): 	0.473256
  validation accuracy: 			67.45 %


  0%|          | 100/20347 [00:47<2:40:14,  2.11it/s]
  4%|▍         | 100/2484 [00:10<04:20,  9.15it/s]


Epoch 3 of 4 took 58.424s
  training loss (in-iteration): 	0.440144
  validation accuracy: 			75.50 %


  0%|          | 100/20347 [00:47<2:40:57,  2.10it/s]
  4%|▍         | 100/2484 [00:10<04:19,  9.20it/s]

Epoch 4 of 4 took 58.573s
  training loss (in-iteration): 	0.484641
  validation accuracy: 			75.50 %



