In [1]:
from tqdm import tqdm
import torch
from omegaconf import OmegaConf

from models import get_model
from trainers.metrics import get_metrics
from trainers.optimizer import get_optimizer
from trainers.criterion import get_criterion
from data.dataloader import DataLoader
from data.dataset import get_dataset
from data.sampler import get_sampler

In [2]:
# Load the experiment settings from the YAML file. The resulting `cfg` is the
# single configuration object passed to the model, optimizer, criterion,
# metrics, dataset, sampler and dataloader factories in the cells below.
cfg = OmegaConf.load('/workspace/app/app.yaml')
# Print the loaded configuration as YAML so the settings used for this run are
# recorded in the notebook output.
print(OmegaConf.to_yaml(cfg))

experiment:
  name: App
data:
  dataset:
    name: cifar10
    rootdir: /workspace/datasets
    num_train_samples: 40000
    in_channel: 3
    num_class: 10
    classes:
    - plane
    - car
    - bird
    - cat
    - deer
    - dog
    - frog
    - horse
    - ship
    - truck
  sampler:
    name: balanced_batch_sampler
train:
  batch_size: 10
  epochs: 1
  save_best_ckpt: true
  num_workers: 2
  ckpt_path: best_ckpt.pth
  eval: false
  optimizer:
    name: adam
    lr: 0.0001
    decay: 0.0001
  trainer:
    name: default
  criterion:
    name: cross_entropy
  metrics:
    name: classification
model:
  name: simple_cnn
  pretrained: false
  initial_ckpt: null



In [3]:
# Build the model from the configuration. Later cells use `model.network`
# (the module that is trained/evaluated) and `model.device` (where batches are
# moved before the forward pass).
model = get_model(cfg)

In [4]:
# Build the training helpers from the configuration. These are module-level
# objects shared by the eval(), save_ckpt() and train() functions defined below.
metrics = get_metrics(cfg)
# The optimizer is constructed for `model.network`, so the model cell must run first.
optimizer = get_optimizer(cfg, model.network)
criterion = get_criterion(cfg)

In [5]:
# Load the train/validation data.
# `mode` is passed to both the dataset and sampler factories; with "trainval"
# the returned objects expose `.train` and `.val` attributes (used below).
mode = "trainval"
dataset = get_dataset(cfg, mode)
sampler = get_sampler(cfg, mode, dataset)
# NOTE: `DataLoader` here is the project's `data.dataloader.DataLoader`
# (see the import cell), which takes `cfg` as its first argument.
train_dataloader = DataLoader(cfg, dataset=dataset.train, sampler=sampler.train)
val_dataloader = DataLoader(cfg, dataset=dataset.val, sampler=sampler.val)

Files already downloaded and verified


In [6]:
# Evaluation function
def eval(eval_dataloader: object = None, epoch: int = 0) -> float:
    """Evaluation

    Evaluates the model on one full pass over ``eval_dataloader``.

    The network is switched to eval mode and every batch is run under
    ``torch.no_grad()``. Outputs, targets and the loss of each batch are
    handed to ``metrics.batch_update``; after the last batch
    ``metrics.epoch_update(epoch, mode='eval')`` is called.

    Notes:
        * This function shadows the built-in ``eval`` in the notebook
          namespace; the name is kept because ``train()`` calls it.
        * The network is left in eval mode on return. ``train()`` switches
          it back to train mode at the start of every epoch.
        * The optimizer is intentionally not touched here: no gradients are
          produced under ``torch.no_grad()``.

    Args:
        eval_dataloader: Iterable yielding ``(inputs, targets)`` batches.
        epoch: Number of epoch. Used for the progress-bar label and passed
            to ``metrics.epoch_update``.

    Returns:
        model_score: Indicator of the excellence of model. The higher the
            value, the better. Read from ``metrics.model_score`` after the
            epoch update.

    Raises:
        ValueError: If ``eval_dataloader`` is ``None``.

    """
    # Fail early with a clear message instead of an opaque TypeError raised
    # when the loop tries to iterate over None.
    if eval_dataloader is None:
        raise ValueError("eval_dataloader must be provided.")

    model.network.eval()

    with torch.no_grad(), tqdm(eval_dataloader, ncols=100) as pbar:
        # The label does not change during the pass, so set it once.
        pbar.set_description(f'eval epoch: {epoch}')

        for inputs, targets in pbar:
            inputs = inputs.to(model.device)
            targets = targets.to(model.device)

            outputs = model.network(inputs)
            loss = criterion(outputs, targets)

            # Hand detached CPU copies to the metrics object so it does not
            # hold references to device tensors.
            metrics.batch_update(outputs=outputs.cpu().detach().clone(),
                                 targets=targets.cpu().detach().clone(),
                                 loss=loss.item())

    metrics.epoch_update(epoch, mode='eval')

    return metrics.model_score

In [7]:
# Save model's check point
def save_ckpt(epoch: int) -> None:
    """Save checkpoint

    Writes a checkpoint dictionary to the path given by
    ``cfg.train.ckpt_path``. The dictionary stores the epoch number, the
    network object itself, the network's state dict and the optimizer's
    state dict.

    Args:
        epoch: Number of epoch recorded in the checkpoint.

    """
    destination = cfg.train.ckpt_path

    checkpoint = {
        'epoch': epoch,
        'model': model.network,
        'model_state_dict': model.network.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }

    torch.save(checkpoint, destination)

# Train function
def train() -> None:
    """Train

    Runs the training loop for ``cfg.train.epochs`` epochs.

    Each epoch puts the network in train mode, iterates once over
    ``train_dataloader`` (forward pass, loss, backward pass, optimizer step,
    gradient reset, metrics update), closes the training metrics for the
    epoch, evaluates on ``val_dataloader`` and saves a checkpoint when
    ``metrics.judge_update_ckpt`` is truthy.

    """
    for epoch in range(cfg.train.epochs):
        print(f"==================== Epoch: {epoch} ====================")
        print("Train:")
        model.network.train()

        with tqdm(train_dataloader, ncols=100) as progress:
            for batch_inputs, batch_targets in progress:
                batch_inputs = batch_inputs.to(model.device)
                batch_targets = batch_targets.to(model.device)

                predictions = model.network(batch_inputs)
                loss = criterion(predictions, batch_targets)

                # Backpropagate, apply the update, then clear the gradients
                # so the next batch starts from zero.
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                metrics.batch_update(
                    outputs=predictions.cpu().detach().clone(),
                    targets=batch_targets.cpu().detach().clone(),
                    loss=loss.item(),
                )

                progress.set_description(f'train epoch:{epoch}')

        metrics.epoch_update(epoch, mode='train')
        eval(eval_dataloader=val_dataloader, epoch=epoch)

        # The metrics object decides whether this epoch deserves a checkpoint.
        if metrics.judge_update_ckpt:
            save_ckpt(epoch=epoch)
            print("Saved the check point.")

    print("Successfully trained the model.")


In [8]:
# Train model: runs the train() loop defined above for cfg.train.epochs epochs,
# evaluating on the validation dataloader after each epoch and saving a
# checkpoint whenever metrics.judge_update_ckpt is set.
train()

Train:


train epoch:0: 100%|████████████████████████████████████████████| 4000/4000 [00:47<00:00, 84.81it/s]
eval epoch: 0: 100%|███████████████████████████████████████████| 1000/1000 [00:05<00:00, 189.41it/s]


Saved the check point.
Successfully trained the model.


In [9]:
# Test model
# NOTE: this cell rebinds `mode`, `dataset` and `sampler`, replacing the
# trainval objects created in the dataset-loading cell above. The existing
# train/val dataloaders are separate names and are not rebound here.
mode = "test"
dataset = get_dataset(cfg, mode)
sampler = get_sampler(cfg, mode, dataset)
test_dataloader = DataLoader(cfg, dataset=dataset.test, sampler=sampler.test)
# `epoch` is not passed, so eval() uses its default of 0: the progress bar is
# labelled "eval epoch: 0" and metrics.epoch_update is called with epoch 0.
# As the last expression of the cell, the returned model score is displayed.
eval(eval_dataloader=test_dataloader)

Files already downloaded and verified


eval epoch: 0: 100%|███████████████████████████████████████████| 1000/1000 [00:05<00:00, 189.70it/s]


43.540000915527344