In [1]:
import math
from torch import nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import time
from torchsummary import summary
import config
from facenet_pytorch import training
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from PIL import Image
import glob
import torchvision.models as models
from util import AverageMeter, learning_rate_decay, Logger

In [2]:
# Dataset locations. The defaults are the original machine-specific paths;
# set CELEBA_DATA_ROOT / CELEBA_ATTR_FILE to run on another machine without
# editing the notebook.
data_root = os.environ.get("CELEBA_DATA_ROOT", "/home/mehmetyavuz/datasets/CelebA128/")
attr_root = os.environ.get("CELEBA_ATTR_FILE", "/home/mehmetyavuz/datasets/list_attr_celeba.txt")

# Number of worker processes used by each DataLoader
workers = 4

# Batch size used by the train, validation and test loaders
batch_size = 64

# Spatial size of the images. This value is passed to the CelebA dataset
# class, which does not resize anything itself.
# NOTE(review): the images under data_root are presumably already 128x128 - confirm.
image_size = (128,128)

# Number of passes over the training split
epochs = 10

In [3]:
class CelebA(data.Dataset):
    """CelebA images paired with a selection of binary attribute labels.

    The attribute file is read with two header lines skipped: the second line
    of the file supplies the attribute names, and every following row holds an
    image file name followed by one value per attribute.

    Args:
        data_path: directory that contains the image files.
        attr_path: path of the attribute text file.
        image_size: stored on the instance for reference only; no resizing is
            performed by this class.
        mode: one of 'train', 'valid' or 'test'; selects the row range.
        selected_attrs: attribute names (as spelled in the file header) to
            return as labels, in this order.

    Raises:
        ValueError: if ``mode`` is not a known split, or if a requested
            attribute name is missing from the header (from ``list.index``).
    """

    # Row ranges (start, stop) of each split within the attribute file.
    _SPLIT_BOUNDS = {
        'train': (0, 162770),
        'valid': (162770, 182637),
        'test': (182637, None),
    }

    def __init__(self, data_path, attr_path, image_size, mode, selected_attrs):
        super(CelebA, self).__init__()
        # Fail fast on a bad split name instead of hitting an AttributeError
        # on self.images after the whole attribute file has been parsed.
        if mode not in self._SPLIT_BOUNDS:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(sorted(self._SPLIT_BOUNDS), mode)
            )
        self.data_path = data_path
        self.image_size = image_size
        self.mode = mode

        # The attribute names live on the second line of the file.
        with open(attr_path, 'r', encoding='utf-8') as attr_file:
            attr_file.readline()
            att_list = attr_file.readline().split()
        # +1 because column 0 of every data row is the image file name.
        atts = [att_list.index(att) + 1 for att in selected_attrs]
        # Builtin str/int replace the np.str/np.int aliases, which no longer
        # exist in current NumPy releases.
        images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
        labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)

        # Applied to every sample: tensor conversion plus per-channel
        # normalisation with mean 0.5 and std 0.5.
        self.tf = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        # Random augmentation, applied to the training split only.
        self.tf_a = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(hue=.05, saturation=.05),
            ], p=0.8),
            transforms.RandomGrayscale(0.2),
        ])

        start, stop = self._SPLIT_BOUNDS[mode]
        self.images = images[start:stop]
        self.labels = labels[start:stop]
        self.length = len(self.images)

    def __getitem__(self, index):
        """Return ``(image_tensor, float32 label tensor)`` for sample ``index``."""
        path = os.path.join(self.data_path, self.images[index])
        # Load inside a context manager so the file handle is released, and
        # force three channels so the 3-channel Normalize always applies.
        with Image.open(path) as img_file:
            img = img_file.convert('RGB')
        # ``index`` is local to the split, so the split name - not the index -
        # decides whether the random augmentation is applied.
        if self.mode == 'train':
            img = self.tf_a(img)
        img = self.tf(img)
        # (label + 1) // 2 maps -1 to 0 and +1 to 1.
        att = torch.tensor((self.labels[index] + 1) // 2)
        return img, att.to(torch.float32)

    def __len__(self):
        """Number of samples in the selected split."""
        return self.length

In [4]:
# The 40 attribute names requested from the attribute file; the order here is
# the order of the label columns returned by the dataset.
attrs_default = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald",
    "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair",
    "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin",
    "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard",
    "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks",
    "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair", "Wearing_Earrings",
    "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie", "Young",
]

In [5]:
# Training split: reshuffled every epoch.
train_dataset = CelebA(
    data_path=data_root,
    attr_path=attr_root,
    image_size=image_size,
    mode='train',
    selected_attrs=attrs_default,
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers
)

# Validation split: fixed order.
dataset = CelebA(
    data_path=data_root,
    attr_path=attr_root,
    image_size=image_size,
    mode='valid',
    selected_attrs=attrs_default,
)
val_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=False, num_workers=workers
)

# Test split: fixed order. `dataset` is rebound here; val_loader keeps its own
# reference to the validation dataset.
dataset = CelebA(
    data_path=data_root,
    attr_path=attr_root,
    image_size=image_size,
    mode='test',
    selected_attrs=attrs_default,
)
test_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=False, num_workers=workers
)

In [6]:
# Decide which device we want to run on: the first CUDA GPU when one is
# available, otherwise the CPU so the notebook still runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
# Despite the variable name, this is an AlexNet with pretrained weights.
resnet = models.alexnet(pretrained=True)
# Swap the last classifier layer for a 40-output linear head.
resnet.classifier[6] = nn.Linear(in_features=4096, out_features=40, bias=True)
# Move the whole model, including the new head, to the target device at once.
resnet = resnet.to(device)

In [9]:
# Bare expression: the notebook renders the model's module tree as the cell output.
resnet

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [12]:
# Freeze every other weight/bias pair of the model.
# Parameters are grouped two at a time (param_idx // 2). A group is trainable
# exactly when its parity differs from the parity of `epoch`: with an even
# epoch the odd groups train, with an odd epoch the even groups train.
epoch = 0
for param_idx, (name, param) in enumerate(resnet.named_parameters()):
    layer_idx = param_idx // 2
    # Only parameters whose name mentions 'weight' or 'bias' are touched.
    if 'weight' in name or 'bias' in name:
        param.requires_grad = (layer_idx % 2) != (epoch % 2)

In [16]:
# List the names of the parameters that are currently trainable.
trainable_names = (
    name for name, param in resnet.named_parameters() if param.requires_grad
)
for trainable_name in trainable_names:
    print(trainable_name)

features.3.weight
features.3.bias
features.8.weight
features.8.bias
classifier.1.weight
classifier.1.bias
classifier.6.weight
classifier.6.bias


In [None]:
# Print every parameter name whose requires_grad flag is set.
for name, param in resnet.named_parameters():
    if not param.requires_grad:
        continue
    print(name)

In [None]:
# Adam over every parameter returned by the model (resnet.parameters() is
# passed unfiltered), with a fixed learning rate of 1e-5.
optimizer = optim.Adam(params=resnet.parameters(), lr=1e-5)

# No learning-rate schedule is used. A cosine-annealing schedule over
# floor(len(train_dataset) * epochs) steps was tried and left disabled.
scheduler = None

In [None]:
# Binary cross-entropy computed on raw logits.
loss_fn = nn.BCEWithLogitsLoss()

# Per-batch metrics handed to training.pass_epoch, keyed by display name.
metrics = dict(acc=training.accuracy_ml)

In [None]:
# Train for `epochs` epochs; whenever the validation loss reaches a new
# minimum, evaluate on the test split and save the weights to "alexnet.pth".
# TensorBoard logging (a `writer` argument to pass_epoch) is left disabled.
#
# NOTE(review): nothing in this loop changes requires_grad, so the freeze
# pattern set up earlier for epoch 0 stays in force for every epoch - confirm
# that this is intended rather than alternating per epoch.

print('\n\nInitial')
print('-' * 10)
resnet.eval()
# Baseline pass on the test split before any training. No gradients are
# needed for evaluation, so autograd bookkeeping is switched off.
with torch.no_grad():
    training.pass_epoch(
        resnet, loss_fn, test_loader,
        batch_metrics=metrics, show_running=True, device=device,
    )

# Start from +inf so the first epoch always counts as the best so far; a
# finite starting value would silently skip checkpointing whenever the
# validation loss stays above it.
val_loss = float('inf')
for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    resnet.train()
    training.pass_epoch(
        resnet, loss_fn, train_loader, optimizer, scheduler,
        batch_metrics=metrics, show_running=True, device=device,
    )

    resnet.eval()
    with torch.no_grad():
        val_metrics = training.pass_epoch(
            resnet, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, device=device,
        )

    # val_metrics[0] is used as the epoch's validation loss.
    # NOTE(review): assumes pass_epoch returns the loss first - confirm.
    current_val_loss = val_metrics[0].item()
    if current_val_loss < val_loss:
        val_loss = current_val_loss
        print('Test set Accuracy Lowest Validation Loss:')
        with torch.no_grad():
            training.pass_epoch(
                resnet, loss_fn, test_loader,
                batch_metrics=metrics, show_running=True, device=device,
            )
        torch.save(resnet.state_dict(), "alexnet.pth")