## Train the convolutional PCA encoder (PCAH)

### Define transforms to use
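Skull defects are simulated online: the holes are applied by the dataset transforms when samples are loaded rather than precomputed on disk. The same transform is used for both the training and the validation splits.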

In [1]:
from torchvision import transforms
from headrecbaselines.utils.datasetHeads import SkullRandomHole

# Skull-defect simulation pipeline: a single SkullRandomHole step wrapped in Compose.
transforms_tr = transforms.Compose([SkullRandomHole()])
# Validation reuses the *same* Compose object as training (identical identity,
# not a copy).
# NOTE(review): since the hole is presumably random, validation inputs (and thus
# the validation error) will vary between runs — consider a fixed/seeded hole for val.
transforms_val = transforms_tr

### Declare the training and validation datasets

In [2]:
from headrecbaselines.utils.datasetHeads import MeshHeadsDataset


def read_partition(path):
    """Return the non-blank lines of a partition file, one entry per line.

    The file is opened with ``with`` so the handle is closed deterministically
    (a bare ``open(path).read()`` leaves closing to the garbage collector).
    Blank lines are skipped so an empty string is never handed to the dataset
    as a sample entry; non-blank lines are returned unchanged.
    """
    with open(path, 'r') as partition_file:
        return [line for line in partition_file.read().splitlines() if line.strip()]


images_train = read_partition('partitions/train.txt')
images_val = read_partition('partitions/validation.txt')
img_path = '../datasets/cq500mesh'

# transforms_tr / transforms_val come from the transforms cell above.
train_dataset = MeshHeadsDataset(images_train, img_path, transforms_tr)
val_dataset = MeshHeadsDataset(images_val, img_path, transforms_val)

print(f'{len(images_train)} images for train and {len(images_val)} for val.')

238 images for train and 66 for val.


### Train

#### Define hyperparameters

In [3]:
# Resampling factor applied to the volume size: [512, 512, 233] -> [435, 435, 198].
interp_factor = 0.85
config = dict(
    # Run name; per the trainer output it also names the TensorBoard folder (trained/PCAH).
    name='PCAH',

    # Training setup
    device="cuda:0",
    weight_decay=1e-5,
    train_batch_size=1,
    val_batch_size=1,
    # NOTE(review): gamma/step_size look like StepLR-style decay settings, and
    # step_size equals epochs (50). If the trainer steps the scheduler once per
    # epoch, the learning rate is never decayed during training — confirm.
    gamma=0.9,
    step_size=50,
    lr=1e-4,
    epochs=50,

    # Model / input geometry (h, w, slices are the resampled sizes above)
    latents=60,
    input_size=1024,
    interpolate=True,
    h=int(512 * interp_factor),
    w=int(512 * interp_factor),
    slices=int(233 * interp_factor),
)

In [4]:
# Build the PCAH network from the hyperparameter dict and train it on the
# datasets declared earlier (train_dataset / val_dataset).
from headrecbaselines.models.PCAH import PCAH_Net
from headrecbaselines.trainer import trainer

model = PCAH_Net(config)
# The trainer is handed the same config dict used to build the model.
# NOTE(review): the recorded output logs "Model Saved MSE (epoch 65)" right after
# epoch 1 of 50, so the epoch number reported by `trainer` looks wrong — check it.
# The saved run was also interrupted (KeyboardInterrupt) during epoch 2.
trainer(train_dataset, val_dataset, model, config)

running in cuda:0
Tensorboard folder: trained/PCAH
Training ...
Epoch [1/50]. 
  train batch [238/238]
    train avg reconstruction error: 20.30212558517937
  val batch [66/66]
    val avg reconstruction error: 10.28929435484337
Model Saved MSE (epoch 65)
Epoch [2/50]. 
  train batch [75/238]

KeyboardInterrupt: 