## Model Training

The model training stage focuses on Xception and three improvements to it. We train four models: a baseline and Improvements 1 to 3. Each model's specification and training workflow is defined in its own "scenario" module under `src.nn.scenarios`.

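Training is computationally heavy, so run this notebook on capable hardware, ideally a CUDA GPU. The training cells below use reduced batch sizes and epoch counts for a quick test run; the intended full-run values are given in their `# Should be N` comments.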
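One way to keep the quick-test and full-run settings in a single place is a small profile object. This is a minimal sketch, not part of the project: `RunProfile`, `FULL`, `QUICK` and `PROFILE` are hypothetical names, and the numbers are simply copied from the `# Should be N` comments and the current arguments in the training cells.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunProfile:
    """Batch sizes and epoch counts for one run (hypothetical helper)."""
    pretraining_batch_size: int
    fine_tuning_batch_size: int
    pretraining_epochs: int
    fine_tuning_epochs: int
    single_stage_batch_size: int  # Improvements 1 and 3 train in one stage
    single_stage_epochs: int


# Intended values, taken from the "# Should be N" comments.
FULL = RunProfile(
    pretraining_batch_size=32, fine_tuning_batch_size=16,
    pretraining_epochs=3, fine_tuning_epochs=15,
    single_stage_batch_size=32, single_stage_epochs=15,
)

# Reduced values the cells currently use, for a quick smoke test.
QUICK = RunProfile(
    pretraining_batch_size=2, fine_tuning_batch_size=2,
    pretraining_epochs=1, fine_tuning_epochs=1,
    single_stage_batch_size=2, single_stage_epochs=2,
)

PROFILE = QUICK  # switch to FULL on a machine with enough GPU memory
```

A cell would then pass, for example, `batch_size=PROFILE.pretraining_batch_size`, so switching between a smoke test and a full run means changing one line.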

### Preliminaries

In [None]:
%run __init__.py

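# Let PyTorch's CUDA caching allocator use expandable segments, which reduces
# memory fragmentation (and out-of-memory errors) during long training runs.
# Takes effect only if set before PyTorch makes its first CUDA allocation.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True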
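The `%env` magic only works inside IPython/Jupyter. When running the same code as a plain Python script, a sketch of an equivalent is to set the variable through `os.environ` before PyTorch allocates any CUDA memory (the safest point is before `import torch`). Unlike `%env`, `setdefault` leaves any value the user has already exported untouched.

```python
import os

# Must run before PyTorch's first CUDA allocation; safest is before `import torch`.
# setdefault keeps any PYTORCH_CUDA_ALLOC_CONF the user already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```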

### Baseline

In [None]:
from torchvision import transforms

from src.nn.scenarios.baseline import XceptionNetBaseline
from src.nn.training import data_loader

In [None]:
pretrained_train_loader, _, pretrained_test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",  # point this at the preprocessed dataset you want to use
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 32
)

fine_tuning_train_loader, _, fine_tuning_test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 16
)

model = XceptionNetBaseline()
model.train(
    pretrain_train_loader=pretrained_train_loader,
    pretrain_test_loader=pretrained_test_loader,
    pretraining_epochs=1,  # Should be 3
    fine_tuning_train_loader=fine_tuning_train_loader,
    fine_tuning_test_loader=fine_tuning_test_loader,
    fine_tuning_epochs=1,  # Should be 15
    save_to="../models/Celeb-DF-v2-Split10-XceptionNetBaseline",
)
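Every cell calls `load_data` with `train_size=0.8` and `val_size=0.1`, which presumably leaves the remaining 10% as the test split. The internals of `load_data` are not shown in this notebook; the sketch below assumes a shuffled, index-based split and only illustrates the arithmetic and why a fixed seed matters. The Baseline (and Improvement 2) build the pretraining and fine-tuning loaders with two separate `load_data` calls on the same dataset, so if each call shuffled independently, the fine-tuning test set could contain images the model already saw while pretraining. `split_counts`, `split_indices` and `seed` are hypothetical names.

```python
import random


def split_counts(n: int, train_size: float = 0.8,
                 val_size: float = 0.1) -> tuple[int, int, int]:
    """Return (train, val, test) sizes; the test split takes the remainder."""
    n_train = int(n * train_size)
    n_val = int(n * val_size)
    return n_train, n_val, n - n_train - n_val


def split_indices(n: int, train_size: float = 0.8, val_size: float = 0.1,
                  seed: int = 42) -> tuple[list[int], list[int], list[int]]:
    """Shuffle indices 0..n-1 with a fixed seed and cut them into three splits."""
    n_train, n_val, _ = split_counts(n, train_size, val_size)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation every call
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])


# With 1,000 images: 800 train, 100 validation, 100 test.
print(split_counts(1000))  # (800, 100, 100)
```

Calling `split_indices` twice with the same `seed` yields identical partitions, which is the property two independent `load_data` calls would need in order to keep their test sets disjoint from each other's training data.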

### Improvement 1

In this model, XceptionNet serves only as a feature extractor; a few additional layers are appended at its tail to act as the classifier.
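A minimal PyTorch sketch of the "frozen feature extractor plus small classifier head" idea follows. It is not the implementation in `src.nn.scenarios.improvement_1`: the class name `FeatureExtractorClassifier`, the head layout (2048 → 512 → 1 with ReLU and dropout) and the single-logit output are assumptions. 2048 is the width of Xception's pooled feature vector, so the backbone passed in is assumed to return a flat `(batch, feature_dim)` tensor. Note that `train()` here is PyTorch's `nn.Module.train` mode switch, not the scenario's `train(...)` training routine used in this notebook.

```python
import torch
from torch import nn


class FeatureExtractorClassifier(nn.Module):
    """Frozen backbone used only for features, plus a small trainable head."""

    def __init__(self, backbone: nn.Module, feature_dim: int = 2048,
                 hidden_dim: int = 512, dropout: float = 0.5) -> None:
        super().__init__()
        self.backbone = backbone
        # Freeze the backbone: its weights are never updated.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        # Hypothetical classifier head appended at the tail of the network.
        self.head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),  # one logit, e.g. for BCEWithLogitsLoss
        )

    def train(self, mode: bool = True) -> "FeatureExtractorClassifier":
        # Switch the head between train/eval as usual, but keep the frozen
        # backbone in eval mode so its BatchNorm statistics stay fixed.
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():  # no gradients needed through the frozen backbone
            features = self.backbone(x)
        return self.head(features)
```

Because the backbone runs under `torch.no_grad()`, only the head's parameters receive gradients, which keeps each training step far cheaper than fine-tuning the whole network.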

In [None]:
from torchvision import transforms

from src.nn.scenarios.improvement_1 import XceptionNetImprovement1
from src.nn.training import data_loader

In [None]:
train_loader, _, test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 32
)

model = XceptionNetImprovement1()
model.train(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=2,  # Should be 15
    save_to="../models/Celeb-DF-v2-Split10-XceptionNetImprovement1",
)

### Improvement 2

In [None]:
from torchvision import transforms

from src.nn.scenarios.improvement_2 import XceptionNetImprovement2
from src.nn.training import data_loader

In [None]:
pretrained_train_loader, _, pretrained_test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 32
)

fine_tuning_train_loader, _, fine_tuning_test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 16
)

model = XceptionNetImprovement2()
model.train(
    pretrain_train_loader=pretrained_train_loader,
    pretrain_test_loader=pretrained_test_loader,
    pretraining_epochs=1,  # Should be 3
    fine_tuning_train_loader=fine_tuning_train_loader,
    fine_tuning_test_loader=fine_tuning_test_loader,
    fine_tuning_epochs=1,  # Should be 15
    save_to="../models/Celeb-DF-v2-Split10-XceptionNetImprovement2",
)
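Baseline and Improvement 2 both train in two stages (`pretraining_epochs`, then `fine_tuning_epochs`). The scenario code is not shown here, so the following is only one common reading of such a schedule, the freeze-then-unfreeze recipe from transfer learning: first train the new head with the backbone frozen, then unfreeze everything and fine-tune with a much smaller learning rate. `set_trainable`, `make_optimizer`, `two_stage_optimizers`, the choice of Adam and both learning rates are hypothetical.

```python
from collections.abc import Iterator

import torch
from torch import nn


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter of a module."""
    for param in module.parameters():
        param.requires_grad = trainable


def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    """Build an optimizer over the currently trainable parameters only."""
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=lr)


def two_stage_optimizers(
    model: nn.Module, backbone: nn.Module,
    head_lr: float = 1e-3, fine_tune_lr: float = 1e-5,
) -> Iterator[tuple[str, torch.optim.Optimizer]]:
    """Yield (stage name, optimizer) pairs for a freeze-then-unfreeze schedule.

    Stage 1 ("pretraining"): backbone frozen, only the new head learns.
    Stage 2 ("fine-tuning"): whole network unfrozen, with a much smaller
    learning rate so the pretrained weights are only nudged.
    """
    set_trainable(backbone, False)
    yield "pretraining", make_optimizer(model, head_lr)
    set_trainable(backbone, True)
    yield "fine-tuning", make_optimizer(model, fine_tune_lr)
```

Because the generator flips `requires_grad` between its two `yield`s, each optimizer is built only after the right parameters have been frozen or unfrozen; building a fresh optimizer per stage also avoids carrying Adam's moment estimates from the head-only stage into fine-tuning.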

### Improvement 3

In [None]:
from torchvision import transforms

from src.nn.scenarios.improvement_3 import XceptionNetImprovement3
from src.nn.training import data_loader

In [None]:
train_loader, _, test_loader = data_loader.load_data(
    path="../data/preprocessed/MTCNN-Celeb-DF-v2",
    transforms=transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor()
    ]),
    train_size=0.8,
    val_size=0.1,
    batch_size=2,  # Should be 32
)

model = XceptionNetImprovement3()
model.train(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=2,  # Should be 15
    save_to="../models/Celeb-DF-v2-Split10-XceptionNetImprovement3",
)