In [1]:
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import argparse
import h5py
import json
import logging
import random
import time
import torch
import os

from models import SleepPredictionMLP
from models import SleepPredictionCNN
from models import SleepPredictionSeq
from dataloader import get_data_loader

In [2]:
class Dict2Obj(dict):
    """A dict whose entries can also be read, assigned and deleted as attributes.

    Attribute reads fall back to the stored keys, attribute writes always
    create or replace a key, and attribute deletes remove a key. Missing names
    raise AttributeError rather than KeyError.
    """

    def __getattr__(self, name):
        # Python only calls this after regular attribute lookup has failed,
        # so real dict methods (keys, items, ...) are never shadowed.
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        return self[name]

    def __setattr__(self, name, value):
        # Every attribute assignment is stored as a dictionary entry.
        self[name] = value

    def __delattr__(self, name):
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        del self[name]

    def merge(self, other, overwrite=True):
        """Copy the entries of `other` into this object.

        When `overwrite` is false, keys that already exist here keep their
        current value; otherwise values from `other` win.
        """
        for key in other:
            if not overwrite and key in self:
                continue
            self[key] = other[key]

In [3]:
def _epoch_from_checkpoint_path(model_path):
    """Return the epoch number encoded in a checkpoint file name, or 0.

    Per-epoch checkpoints are written as 'model-<epoch>.pkl'. Only the file
    name is inspected (hyphens in directory names are ignored), and names
    without a numeric epoch field, such as the final 'model.pkl', yield 0.
    """
    fields = os.path.basename(model_path).split('-')
    if len(fields) < 2:
        return 0
    token = fields[1].split('.')[0]
    return int(token) if token.isdigit() else 0


def main(args):
    """Train a sleep-stage classifier and checkpoint it after every epoch.

    Args:
        args: mapping indexed by '--flag' strings. Keys read here are
            '--seed', '--model-dir', '--dataset', '--labels-data',
            '--batch-size', '--num-workers', '--mode', '--output-size',
            '--pretrained', '--model-path', '--learning-rate', '--patience',
            '--num-epochs' and '--log-step'. The whole mapping is also
            written to args.json and forwarded to run_eval.

    Side effects:
        Creates '--model-dir' if needed and writes train.log, args.json,
        one 'model-<epoch>.pkl' per epoch and a final 'model.pkl' into it.
    """
    # Setting up seeds.
    torch.cuda.manual_seed(args['--seed'])
    torch.manual_seed(args['--seed'])
    # Dedicated RNG for the train/validation split: the same seed always
    # yields the same split (so a resumed run does not train on what used to
    # be validation data) without touching the global `random` state.
    split_rng = random.Random(args['--seed'])

    # Create model directory.
    if not os.path.exists(args['--model-dir']):
        os.makedirs(args['--model-dir'])

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args['--model-dir'], 'train.log')
    logging.basicConfig(filename=logfile, level=logging.INFO, format=log_format)
    root_logger = logging.getLogger()
    # Re-running main() in the same kernel must not stack another console
    # handler, otherwise every message is printed once per previous run.
    # (FileHandler subclasses StreamHandler, hence the exact type check.)
    if not any(type(h) is logging.StreamHandler for h in root_logger.handlers):
        root_logger.addHandler(logging.StreamHandler())

    # Save the arguments.
    with open(os.path.join(args['--model-dir'], 'args.json'), 'w') as args_file:
        json.dump(args, args_file)

    # Build data loaders: 90% of the indices for training, the rest for validation.
    logging.info("Building data loader...")
    dataset_size = get_dataset_size(args['--dataset'])
    train_set = split_rng.sample(range(dataset_size), int(dataset_size * 0.9))
    train_lookup = set(train_set)
    val_set = [idx for idx in range(dataset_size) if idx not in train_lookup]
    data_loader = get_data_loader(args['--dataset'], args['--labels-data'], args['--batch-size'], shuffle=True,
                                  num_workers=args['--num-workers'], model_type=args['--mode'],
                                  indices=train_set)
    val_data_loader = get_data_loader(args['--dataset'], args['--labels-data'], args['--batch-size'], shuffle=False,
                                      num_workers=args['--num-workers'], model_type=args['--mode'],
                                      indices=val_set)
    logging.info("Done")

    # Build the models
    logging.info("Building Sleep Stage Predictor...")
    if args['--mode'] == 'mlp':
        model = SleepPredictionMLP(1250, 256, num_classes=args['--output-size'],
                                   num_layers=2, dropout_p=0.0, w_norm=False)
    elif args['--mode'] == 'cnn':
        model = SleepPredictionCNN(output_size=args['--output-size'])
    else:
        model = SleepPredictionSeq(output_size=args['--output-size'])
    # Epoch offset so checkpoints of a resumed run continue the numbering.
    pre = 0
    if args['--pretrained']:
        # Load onto the CPU first so a checkpoint saved on a GPU machine also
        # loads on a CPU-only one; the model is moved to the GPU below.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(args['--model-path'],
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
        pre = _epoch_from_checkpoint_path(args['--model-path'])
    logging.info("Done")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    # Loss and Optimizer.
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion.cuda()

    # Parameters to train.
    optimizer = torch.optim.Adam(model.parameters(), lr=args['--learning-rate'], weight_decay=0.001)
    # The scheduler is not stepped here; it is handed to run_eval.
    scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min',
                                  factor=0.1, patience=args['--patience'],
                                  verbose=True, min_lr=1e-6)

    # Train the Models.
    num_batches = len(data_loader)
    for epoch in range(args['--num-epochs']):
        epoch_loss = 0.0
        epoch_correct = 0.0
        epoch_seen = 0
        start_time = time.time()
        for i, (conditions, eegs, labels) in enumerate(data_loader):
            # Set mini-batch dataset.
            if use_cuda:
                eegs = eegs.cuda()
                conditions = conditions.cuda()
                labels = labels.cuda()
            # Forward.
            model.train()
            model.zero_grad()

            outputs = model(eegs, conditions)

            # Calculate the loss.
            loss = criterion(outputs, labels)

            # Backprop and optimize.
            loss.backward()
            optimizer.step()

            # Running statistics. Accuracy is normalised by the number of
            # labels actually in the batch, which can be smaller than
            # '--batch-size' for the last batch of an epoch.
            _, preds = torch.max(outputs.data, 1)
            batch_correct = (preds == labels).sum().item()
            batch_seen = labels.numel()
            epoch_correct += batch_correct
            epoch_seen += batch_seen
            epoch_loss += loss.item()

            if (i+1) % args['--log-step'] == 0:
                logging.info('Time: %.2f Epoch [%d/%d], step [%d/%d], Train loss: %.4f, Train accuracy: %.4f' % (
                    time.time() - start_time, epoch, args['--num-epochs'], i+1, num_batches,
                    loss.item(), batch_correct / float(max(batch_seen, 1))))
                # The reported time covers only the steps since the last log line.
                start_time = time.time()

        torch.save(model.state_dict(), os.path.join(args['--model-dir'],
                   'model-%d.pkl' % (epoch+1+pre)))

        # Evaluation and learning rate updates.
        logging.info('Epoch [%d/%d], Train loss: %.4f, Train accuracy: %.4f' % (
                    epoch, args['--num-epochs'], epoch_loss / max(num_batches, 1),
                    epoch_correct / float(max(epoch_seen, 1))))
        # NOTE(review): run_eval is not defined in the cells shown here; it
        # must exist in the kernel before main() is called - confirm.
        run_eval(model, val_data_loader, criterion, args, epoch, scheduler)

    # Save the final model.
    torch.save(model.state_dict(),os.path.join(args['--model-dir'],'model.pkl'))


In [4]:
# All run settings in one place, keyed by their command-line style flag names.
args_dict = {
    # Paths.
    '--model-dir': 'weights/mlp/',
    '--model-path': 'weights/resnet/model.pkl',
    '--dataset': 'data/processed_train_dataset.hdf5',
    '--labels-data': 'data/y_train_2.csv',

    # Session parameters.
    '--log-step': 10,
    '--save-step': 1000,
    '--eval-steps': None,
    '--eval-every-n-steps': 1000,
    '--eval-all': True,
    '--num-epochs': 10,
    '--batch-size': 16,
    '--num-workers': 8,
    '--learning-rate': 0.001,
    '--patience': 0,
    '--data-size': 100,
    '--seed': 1,

    # Model parameters.
    '--mode': 'mlp',
    '--output-size': 3,
    '--pretrained': False,
}

In [5]:
def get_dataset_size(dataset, key='mlp'):
    """Return the number of entries along the first axis of an HDF5 dataset.

    Args:
        dataset: path of the HDF5 file to open (read-only).
        key: name of the dataset inside the file whose first dimension is
            reported. Defaults to 'mlp'.

    Returns:
        The first element of the dataset's shape.

    Raises:
        KeyError: if `key` is not present in the file.
    """
    # The context manager closes the file even when the lookup fails.
    with h5py.File(dataset, 'r') as annos:
        return annos[key].shape[0]

In [6]:
# Grab all the training parameters.
# Dict2Obj is a dict subclass, so every setting stays reachable as
# args['--flag']; the flag names are not valid identifiers, so attribute-style
# access to them would have to go through getattr(args, '--flag').
args = Dict2Obj(args_dict)
# Last expression of the cell: shows which settings are defined.
list(args.keys())

['--num-epochs',
 '--pretrained',
 '--seed',
 '--eval-all',
 '--mode',
 '--data-size',
 '--model-path',
 '--output-size',
 '--log-step',
 '--eval-steps',
 '--dataset',
 '--eval-every-n-steps',
 '--model-dir',
 '--save-step',
 '--learning-rate',
 '--num-workers',
 '--batch-size',
 '--patience',
 '--labels-data']

In [7]:
# Start training with the plain settings dict; main() looks every option up
# by its '--flag' key.
main(args_dict)

Building data loader...
Done
Building Sleep Stage Predictor...
Done
Time: 9.53 Epoch [0/10], step [10/14717], Train loss: 2.0331, Train accuracy: 0.0000
Time: 4.55 Epoch [0/10], step [20/14717], Train loss: 1.4200, Train accuracy: 0.0000
Time: 4.54 Epoch [0/10], step [30/14717], Train loss: 1.5190, Train accuracy: 0.0000
Time: 4.53 Epoch [0/10], step [40/14717], Train loss: 1.1964, Train accuracy: 0.0000
Time: 8.85 Epoch [0/10], step [50/14717], Train loss: 1.6673, Train accuracy: 0.0000
Time: 4.53 Epoch [0/10], step [60/14717], Train loss: 1.4077, Train accuracy: 0.0000
Time: 4.53 Epoch [0/10], step [70/14717], Train loss: 1.3757, Train accuracy: 0.0000
Time: 4.55 Epoch [0/10], step [80/14717], Train loss: 1.3761, Train accuracy: 0.0000
Time: 8.83 Epoch [0/10], step [90/14717], Train loss: 1.1389, Train accuracy: 0.0000
Time: 4.51 Epoch [0/10], step [100/14717], Train loss: 1.2444, Train accuracy: 0.0000
Time: 4.48 Epoch [0/10], step [110/14717], Train loss: 1.1523, Train accuracy: 0.

KeyboardInterrupt: 