In [None]:
!pip install -U catalyst

Requirement already up-to-date: catalyst in /usr/local/lib/python3.7/dist-packages (21.3)


In [None]:
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

In [None]:
SEED = 42
# Seed torch *before* building the model: nn.Linear draws its initial weights from
# torch's global RNG, so without this every run starts from different parameters
# and the reported metrics are not reproducible.
torch.manual_seed(SEED)

# Linear classifier: flattened 28x28 MNIST image -> 10 class logits.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [None]:
# MNIST is stored in (and downloaded on first run to) the current working directory.
data_dir = os.getcwd()
loaders = {
    # Shuffle the training split so each epoch visits batches in a fresh order;
    # without it every epoch iterates the dataset in the same fixed sequence.
    "train": DataLoader(
        MNIST(data_dir, train=True, download=True, transform=ToTensor()),
        batch_size=32,
        shuffle=True,
    ),
    # Validation order does not affect the metrics, so it stays deterministic.
    "valid": DataLoader(
        MNIST(data_dir, train=False, download=True, transform=ToTensor()),
        batch_size=32,
    ),
}

In [None]:
# Supervised runner with explicit key names for the model input, the ground-truth
# target, the model output and the loss value; callbacks select tensors by these names.
runner = dl.SupervisedRunner(
    input_key="features",
    target_key="targets",
    output_key="logits",
    loss_key="loss",
)

In [None]:
# Model training: one epoch, tracking top-1/3/5 accuracy and a confusion matrix.
# The checkpoint with the lowest "valid" loss is kept under ./logs and reloaded at the end.
train_callbacks = [
    dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
    # catalyst[ml] required
    dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
]
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    callbacks=train_callbacks,
    num_epochs=1,
    logdir="./logs",
    # checkpoint selection: loss on the "valid" loader, lower is better
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    load_best_on_end=True,
    verbose=True,
)

Hparams (experiment): {}


HBox(children=(FloatProgress(value=0.0, description='1/1 * Epoch (train)', max=1875.0, style=ProgressStyle(des…

  for k, v in self.batch_metrics.items()



train (1/1) accuracy: 0.8802833557128906 | accuracy/std: 0.07081212792674185 | accuracy01: 0.8802833557128906 | accuracy01/std: 0.07081212792674185 | accuracy03: 0.9749666452407837 | accuracy03/std: 0.03580065525942475 | accuracy05: 0.9921166896820068 | accuracy05/std: 0.02176691186632642 | loss: 0.5139051675796509 | loss/std: 0.3664878010749817 | lr: 0.02 | momentum: 0.9


  for k, v in metrics.items()


HBox(children=(FloatProgress(value=0.0, description='1/1 * Epoch (valid)', max=313.0, style=ProgressStyle(desc…


valid (1/1) accuracy: 0.8496999740600586 | accuracy/std: 0.08438007466043124 | accuracy01: 0.8496999740600586 | accuracy01/std: 0.08438007466043124 | accuracy03: 0.9573000073432922 | accuracy03/std: 0.04357493405029723 | accuracy05: 0.9886000156402588 | accuracy05/std: 0.019559607813700718 | loss: 0.8412032127380371 | loss/std: 0.5866137742996216 | lr: 0.02 | momentum: 0.9
* Epoch (1/1) 
Top best models:
logs/checkpoints/train.1.pth	0.8412


In [None]:
# One validation batch of images, used as example input for tracing and ONNX export.
features_batch = next(iter(loaders["valid"]))[0]

# Model stochastic weight averaging: average the weights of the `*.pth` checkpoints
# saved under ./logs and load the result into the model.
# NOTE(review): weights are loaded into `model`, while the steps below use `runner.model`;
# this assumes both refer to the same module (runner does not wrap it) — confirm for multi-device runs.
model.load_state_dict(utils.get_averaged_weights_by_path_mask(logdir="./logs", path_mask="*.pth"))

# Model tracing: keep the returned traced module (presumably TorchScript) instead of discarding it.
traced_model = utils.trace_model(model=runner.model, batch=features_batch)

# Model quantization produces a separate quantized model and leaves `runner.model` in float
# (the exported ONNX graph still has Float weights), so the result must be kept to be usable.
quantized_model = utils.quantize_model(model=runner.model)

# Model pruning: modifies `runner.model` in place (80% of weights by L1 magnitude). The mask
# is kept as a separate `weight_mask` alongside `weight_orig`, and both end up in the export.
utils.prune_model(model=runner.model, pruning_fn="l1_unstructured", amount=0.8)

# ONNX export of the pruned float model.
utils.onnx_export(model=runner.model, batch=features_batch, file="./logs/mnist.onnx", verbose=True)



graph(%0 : Float(32, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cpu),
      %1.bias : Float(10, strides=[1], requires_grad=1, device=cpu),
      %1.weight_orig : Float(10, 784, strides=[784, 1], requires_grad=1, device=cpu),
      %1.weight_mask : Float(10, 784, strides=[784, 1], requires_grad=0, device=cpu)):
  %4 : Float(32, 784, strides=[784, 1], requires_grad=0, device=cpu) = onnx::Flatten[axis=1](%0) # /usr/local/lib/python3.7/dist-packages/torch/nn/modules/flatten.py:40:0
  %5 : Float(10, 784, strides=[784, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1](%1.weight_mask) # /usr/local/lib/python3.7/dist-packages/torch/nn/utils/prune.py:74:0
  %6 : Float(10, 784, strides=[784, 1], requires_grad=1, device=cpu) = onnx::Mul(%5, %1.weight_orig) # /usr/local/lib/python3.7/dist-packages/torch/nn/utils/prune.py:74:0
  %7 : Float(32, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%4, %6, %1.bias) # /usr/local/lib/python3.7/