# face2age

In [18]:
import torch
import torch.nn
from tqdm import tqdm
import torch.nn as nn
from pathlib import Path
import torch.optim as optim
import torchvision.models as models
from image_dataset import ImageDataset
from sklearn.model_selection import train_test_split
from torchvision.transforms import ToPILImage
from torch.utils.data import DataLoader, Subset

### State paths

In [19]:
DATA_DIR = Path('data')
TRAIN_IMGS = DATA_DIR / 'wiki_labeled'
TEST_IMGS = DATA_DIR / 'wiki_judge_images'

labels_file = DATA_DIR / 'wiki_labels.csv'
judge_ids_file = DATA_DIR / 'wiki_judge.csv'

## Define Network

In [26]:
# Leverage pretrained VGG16 to extract the first set of features
class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super(VGG16FeatureExtractor, self).__init__()
        # Load pre-trained VGG-16 model
        self.vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        
        # Freeze parameters
        for param in self.vgg16.parameters():
            param.requires_grad = False
        
    def forward(self, x):
        out = self.vgg16(x)
        
        return torch.FloatTensor(out)
    

class AgePredictor(nn.Module):
    def __init__(self):
        super(AgePredictor, self).__init__()
        
        self.flat = nn.Flatten()
        self.l0 = nn.Linear(1000, 300)
        self.l1 = nn.Linear(300, 100)
        self.l2 = nn.Linear(100, 1)
        
        self.body = nn.Sequential(
            self.flat,
            self.l0,
            nn.ReLU(inplace=True),
            self.l1,
            nn.ReLU(inplace=True),
            self.l2,
            nn.Sigmoid()
        )
    
    def forward(self, x):
        out = self.body(x)
        return out
    
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.vgg = VGG16FeatureExtractor()
        self.age_predictor = AgePredictor()
        
    def forward(self, x):
        vgg_out = self.vgg(x)
        age_out = self.age_predictor(vgg_out)
        
        return age_out

## Setup the datasets

In [27]:
# Define training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_size = 0.7
validation_size = 0.2
test_size = 0.1
batch_size = 4
epochs = 25
save_model = True

# Get training data
train_data = ImageDataset(
    labels_file, 
    TRAIN_IMGS,
    transform=ToPILImage()
)

# Get indices for train-validation-test split
all_indices = list(range(len(train_data)))
train_indices, test_indices = train_test_split(
    all_indices, 
    test_size=test_size, 
    random_state=42,
)
train_indices, val_indices = train_test_split(
    train_indices, 
    test_size=validation_size/(1-test_size), 
    random_state=42,
)

# Define datasets for each partition
train_dataset = Subset(train_data, train_indices)
val_dataset = Subset(train_data, val_indices)
test_dataset = Subset(train_data, test_indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1, shuffle=False)

## Train the model

In [28]:
loss_fn = nn.MSELoss()
model = Model().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0003)
for epoch in range(1, epochs+1):
    run_result = {'nsamples': 0, 'loss': 0}
    
    for p in model.parameters():
        if p.grad is not None:
            del p.grad  # free some memory
    torch.cuda.empty_cache()
    
    train_bar = tqdm(train_loader)
    for data, target in train_bar:
        batch_size = data.size(0)
        run_result['nsamples'] += batch_size
        
        label = target.to(device)
        z = data.to(device)
        pred_age = model(z.float())

        ######### Train generator #########
        label = label.unsqueeze(1)
        label = label.float()
        model.zero_grad()
        loss = loss_fn(pred_age, label)
        loss.backward()
        optimizer.step()

  0%|          | 0/10557 [00:00<?, ?it/s]

tensor([[0.5638],
        [0.5481],
        [0.5412],
        [0.5449]], grad_fn=<SigmoidBackward0>)


  0%|          | 2/10557 [00:11<15:25:13,  5.26s/it]

tensor([[0.6477],
        [0.6727],
        [0.6621],
        [0.6408]], grad_fn=<SigmoidBackward0>)


  0%|          | 3/10557 [00:14<13:22:01,  4.56s/it]

tensor([[0.7736],
        [0.8317],
        [0.8700],
        [0.7277]], grad_fn=<SigmoidBackward0>)


  0%|          | 4/10557 [00:18<12:17:38,  4.19s/it]

tensor([[0.8540],
        [0.8798],
        [0.7999],
        [0.9244]], grad_fn=<SigmoidBackward0>)


  0%|          | 5/10557 [00:22<12:04:50,  4.12s/it]

tensor([[0.8772],
        [0.9865],
        [0.9817],
        [0.6696]], grad_fn=<SigmoidBackward0>)


  0%|          | 6/10557 [00:26<11:37:07,  3.96s/it]

tensor([[0.7144],
        [0.9824],
        [0.9915],
        [0.9330]], grad_fn=<SigmoidBackward0>)


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x11bef7380>
Traceback (most recent call last):
  File "/Users/parkerhicks/opt/anaconda3/envs/pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/parkerhicks/opt/anaconda3/envs/pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/parkerhicks/opt/anaconda3/envs/pytorch/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/parkerhicks/opt/anaconda3/envs/pytorch/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/parkerhicks/opt/anaconda3/envs/pytorch/lib/python3.12/multiprocessing/connection.py", line 1135, in wait


KeyboardInterrupt: 