In [None]:
%%capture
# Run the repository's pull script (presumably fetches/updates the
# `moscontest` checkout that the next cell enters — TODO confirm).
# NOTE: %%capture swallows all of its output, and a failing `!` command does
# not raise, so an error here is silent; drop %%capture when debugging.
! sh pull.sh

In [None]:
# Enter the project checkout. The kernel's working directory persists across
# cells, so an unguarded `cd moscontest` fails (or descends into a nested
# path) when this cell is re-run; only change directory if not already there.
import os
if os.path.basename(os.getcwd()) != 'moscontest':
    os.chdir('moscontest')

In [None]:
%%capture
# Fetch the DVC-tracked processed-data archive and extract it into the current
# directory. `-o` overwrites existing files without prompting: on a re-run,
# plain `unzip` asks "replace ...? [y]es, [n]o, ..." and can block waiting for
# an answer that a notebook `!` command cannot supply (and %%capture hides the
# prompt), so the cell would appear to hang.
! dvc pull data/processed/processed.zip.dvc
! unzip -o 'data/processed/processed.zip'

In [1]:
import copy

import torch
from torchvision import transforms
from catalyst import utils, dl
from src.nn import FERCNN, AUCCallback, PrecisionRecallF1ScoreCallback, ConfusionMatrixCallback
from src.data.features import FER
from src.utils import IMG_SIZE
# Reproducibility setup. cuDNN autotuning (`benchmark`) may select different
# convolution kernels from run to run, which defeats `deterministic=True`, so
# it is disabled explicitly instead of relying on catalyst's default (which
# enables it unless the CUDNN_BENCHMARK env var says otherwise — verify for the
# pinned catalyst version). Trade-off: training may run somewhat slower.
utils.prepare_cudnn(deterministic=True, benchmark=False)
utils.set_global_seed(7)
device = utils.get_device()

In [2]:
# Training-time preprocessing + augmentation.
# Grayscale() (default: 1 output channel) runs first, so only brightness and
# contrast jitter can change the image: torchvision's saturation and hue
# adjustments are no-ops on single-channel images. They were dropped so the
# pipeline does not advertise augmentation that never happens.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.RandomPerspective(distortion_scale=0.15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

In [3]:
kwargs = {'batch_size': 700, 'num_workers': 8}
threshold = 35000  # index split: [0, threshold) -> train, [threshold, end) -> valid
dataset = FER(exclude=['contempt'], transform=transform)

# Validation must not go through the random augmentations in `transform`;
# otherwise validation metrics are noisy and not comparable across epochs.
# A shallow copy keeps exactly the same sample order (so the index split is
# unchanged) while swapping in a deterministic preprocessing pipeline.
eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
])
valid_dataset = copy.copy(dataset)
# NOTE(review): assumes FER applies `self.transform` per sample (torchvision
# convention) — fail loudly instead of silently augmenting validation if not.
assert hasattr(valid_dataset, 'transform'), 'FER does not expose a `transform` attribute'
valid_dataset.transform = eval_transform

train_dataloader = dataset.data_loader(end=threshold, shuffle=True, drop_last=True, **kwargs)
valid_dataloader = valid_dataset.data_loader(start=threshold, **kwargs)
loaders = {'train': train_dataloader, 'valid': valid_dataloader}

In [4]:
# Model, class-weighted loss, optimiser + LR schedule, metric callbacks, log dir.
class_names = dataset.classes
num_classes = len(class_names)
model = FERCNN(num_classes)

# Per-class loss weights supplied by the dataset, placed on the training device
# so CrossEntropyLoss can use them.
weight = dataset._weight_classes().to(device)
criterion = torch.nn.CrossEntropyLoss(weight=weight)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)
# Scale the learning rate by 0.1 once the scheduler's step count reaches 100.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[100],
    gamma=0.1,
)

# All metric callbacks take the same class configuration.
metric_kwargs = dict(num_classes=num_classes, class_names=class_names)
callbacks = [
    AUCCallback(**metric_kwargs),
    PrecisionRecallF1ScoreCallback(**metric_kwargs),
    ConfusionMatrixCallback(**metric_kwargs),
]
logdir = 'logs/FERCNN/'

In [5]:
# Show the per-class weights passed to CrossEntropyLoss (one entry per class).
weight

tensor([2.5099, 9.3922, 2.6786, 1.0000, 1.6625, 1.7839, 2.6140],
       device='cuda:0')

In [None]:
%load_ext tensorboard
# Live view of the training logs written under the run directory set above.
%tensorboard --logdir {logdir}

In [None]:
# Train with catalyst's supervised runner; logs go under `logdir`
# (the directory the TensorBoard cell above points at).
train_config = dict(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    callbacks=callbacks,
    loaders=loaders,
    logdir=logdir,
    num_epochs=200,  # the LR scheduler's milestone is at 100
    verbose=True,
)
runner = dl.SupervisedRunner(device=device)
runner.train(**train_config)

In [None]:
# Publish the run's logs via the repository's push script, which lives one
# level up because an earlier cell changed into `moscontest`. The expanded
# path is quoted so it reaches the script as a single argument even if it
# ever contains spaces.
! sh ../push.sh "$logdir"