In [None]:
import functools
import gc

import matplotlib.pyplot as plt

from IPython.display import clear_output

import torch.nn.functional as tnnf
import torch
from ignite.engine import create_supervised_trainer, Events

In [None]:
def do_every_num_epochs(num_epochs):
    """Build a decorator that lets a handler run only on every `num_epochs`-th epoch.

    The wrapped handler receives the engine as its first argument and is
    invoked only when `engine.state.epoch` is a multiple of `num_epochs`;
    on all other epochs the wrapper returns None without calling it.

    Place this decorator *below* `@trainer.on(...)` (i.e. closer to the
    function), so that the gated wrapper, not the raw handler, is what gets
    passed to `trainer.on`.
    """
    def gate(handler):
        @functools.wraps(handler)
        def gated(engine, *args, **kwargs):
            # Skip epochs that are not a multiple of the requested period.
            if engine.state.epoch % num_epochs != 0:
                return None
            return handler(engine, *args, **kwargs)
        return gated
    return gate
                      
def train_and_evaluate(
    X_train, y_train, X_test, y_test,
    model, optimizer,
    eval_every_num_epochs, plot_every_num_epochs,
    num_epochs
):
    """Train `model` with an ignite supervised trainer, evaluating and plotting periodically.

    Each epoch consists of a single batch: the whole `(X_train, y_train)` pair.
    Cross-entropy is used both as the training loss and as the reported
    evaluation loss.

    Args:
        X_train, y_train: training inputs and targets, passed to the trainer
            as one batch and also used for the periodic train-set evaluation.
        X_test, y_test: held-out inputs and targets used for the periodic
            test-set evaluation.
        model: the model to train; `model(X)` must return logits whose
            `argmax(dim=1)` is comparable with the targets.
        optimizer: optimizer handed to `create_supervised_trainer`.
        eval_every_num_epochs: evaluate on train and test every this many epochs.
        plot_every_num_epochs: redraw the loss/accuracy plot every this many epochs.
        num_epochs: total number of epochs to run.

    Returns:
        A tuple `(model, evaluations_epochs, train_log, test_log)` where
        `evaluations_epochs` lists the epochs at which evaluation ran and each
        log is a dict with parallel lists under "losses" and "accuracies".
    """
    trainer = create_supervised_trainer(
        model=model, optimizer=optimizer,
        loss_fn=tnnf.cross_entropy
    )

    evaluations_epochs = []
    train_log = {
        "losses": [],
        "accuracies": []
    }
    test_log = {
        "losses": [],
        "accuracies": []
    }

    def evaluate(X, y, log):
        """Append the cross-entropy loss and accuracy of `model` on (X, y) to `log`."""
        # The mode is not restored afterwards.
        # NOTE(review): presumably the ignite training step switches the model
        # back to train mode on the next iteration - confirm for the ignite version in use.
        model.train(False)
        # Evaluation needs no gradients; without no_grad a full autograd graph
        # would be built over the whole dataset on every evaluation.
        with torch.no_grad():
            logits = model(X)
            loss = tnnf.cross_entropy(logits, y).item()
            predictions = logits.argmax(dim=1)
            accuracy = (y == predictions).sum().item() / len(y)
        log["losses"].append(loss)
        log["accuracies"].append(accuracy)

    @trainer.on(Events.EPOCH_COMPLETED)
    @do_every_num_epochs(eval_every_num_epochs)
    def evaluate_on_train_and_test(engine):
        evaluate(X_train, y_train, train_log)
        evaluate(X_test, y_test, test_log)
        # The epoch is used as a plot x-coordinate, so it must be a plain number.
        assert not isinstance(engine.state.epoch, torch.Tensor)
        evaluations_epochs.append(engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    @do_every_num_epochs(plot_every_num_epochs)
    def update_plot(engine):
        # A plot epoch can arrive before the first evaluation (e.g. plot every
        # epoch, evaluate every second one); there is nothing to draw yet and
        # indexing the empty accuracy list would raise IndexError.
        if not evaluations_epochs:
            return
        clear_output(wait=True)
        fig, (loss_ax, accuracy_ax) = plt.subplots(ncols=2, figsize=(14, 5))
        loss_ax.set_title("Loss")
        loss_ax.plot(evaluations_epochs, train_log["losses"], label="train loss")
        loss_ax.plot(evaluations_epochs, test_log["losses"], label="test loss")
        loss_ax.legend()
        accuracy_ax.set_title(f"Accuracy. Test: {test_log['accuracies'][-1]}")
        accuracy_ax.plot(evaluations_epochs, train_log["accuracies"], label="train accuracy")
        accuracy_ax.plot(evaluations_epochs, test_log["accuracies"], label="test accuracy")
        accuracy_ax.legend()
        plt.show()
        # Release the figure so repeated redraws do not accumulate open figures.
        plt.close(fig)

    # The dataset is a one-element list, so each epoch is exactly one full-batch step.
    trainer.run([(X_train, y_train)], max_epochs=num_epochs)
    return model, evaluations_epochs, train_log, test_log

In [None]:
def memreport():
    """Print CUDA memory usage followed by the type and size of every live tensor.

    Prints the values reported by `torch.cuda.memory_allocated()` and the
    caching allocator's reserved memory, each divided by 1024**3, and then one
    line per tensor found among the objects tracked by the garbage collector.
    Returns None.
    """
    # `torch.cuda.memory_cached` is the deprecated name of `memory_reserved`;
    # use the new name when the installed torch has it and fall back otherwise.
    reserved_bytes = getattr(torch.cuda, "memory_reserved", None)
    if reserved_bytes is None:
        reserved_bytes = torch.cuda.memory_cached
    print(f"""
    {torch.cuda.memory_allocated()/1024/1024/1024} Gb allocated
    {reserved_bytes()/1024/1024/1024} Gb cached
    """)
    # Walk every object the garbage collector tracks and report the tensors.
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            print(type(obj), obj.size())