# Model Optimization Seminar

Сравнение различных методов оптимизации моделей: TorchScript, ONNX, Pruning, Graph Optimization


In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import psutil
import onnx
import onnxruntime as ort
from torch.utils.data import DataLoader
from torchmetrics.classification import Accuracy, F1Score
import sys
sys.path.append('../../lesson5/seminar')
from text_features_data import TextFeaturesDataModule
import pytorch_lightning as pl
from pytorch_lightning import Trainer


## 1. Загрузка датасета

Загружаем датасет AG News из модуля text_features_data.py. Датасет содержит TF-IDF векторы и статистические фичи для классификации новостных текстов.


In [2]:
dm = TextFeaturesDataModule(
    max_features=20000,
    batch_size=128,
    use_bigrams=True,
    cache_dir="./cache",
    num_workers=0,
    pin_memory=False,
)

dm.prepare_data()
dm.setup()

print(f"Input dim: {dm.input_dim}")
print(f"Num classes: {dm.n_classes}")
print(f"Train samples: {len(dm.train_dataset)}")
print(f"Val samples: {len(dm.val_dataset)}")


Preparing AG News dataset...


Loading train: 100%|██████████| 120000/120000 [00:01<00:00, 86933.57it/s]
Loading test: 100%|██████████| 7600/7600 [00:00<00:00, 83117.46it/s]


Total texts: 127600
Split: train=108460, val=19140
Extracting TF-IDF features (max_features=20000, ngrams=(1, 2))...


TF-IDF train: 100%|██████████| 108460/108460 [00:03<00:00, 27188.23it/s]
TF-IDF val: 100%|██████████| 19140/19140 [00:00<00:00, 29234.53it/s]


TF-IDF dimension: 20000
Extracting statistical features...


Statistical features: 100%|██████████| 108460/108460 [00:03<00:00, 29154.09it/s]
Statistical features: 100%|██████████| 19140/19140 [00:00<00:00, 31554.86it/s]


Concatenating train features (108460x20000 + 5)...
Concatenating val features (19140x20000 + 5)...
Total feature dimension: 20005
Normalizing statistical features...
Normalization done
Saving to cache...
Saving train X (108460x20005, 16553.8 MB)...
Saving train y...
Saving val X (19140x20005, 2921.3 MB)...
Saving val y...
Saved to cache
Dataset ready: classes=4, features=20005, train=108460, val=19140
Input dim: 20005
Num classes: 4
Train samples: 108460
Val samples: 19140


## 2. Построение Multi-Branch модели

Создаем модель с multi-branch архитектурой. Модель содержит три параллельные ветки: Bottleneck, Inverted Bottleneck и Regular блоки. Результаты веток объединяются через конкатенацию.


In [8]:
class BottleneckBlock(nn.Module):
    def __init__(self, dim, activation='gelu', dropout=0.0):
        super().__init__()
        self.dim = dim
        self.bottleneck_dim = max(dim // 4, 1)
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.fc1 = nn.Linear(self.dim, self.bottleneck_dim)
        self.fc2 = nn.Linear(self.bottleneck_dim, self.dim)
    
    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.activation(out)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.fc2(out)
        return out + identity

class InvertedBottleneckBlock(nn.Module):
    def __init__(self, dim, expansion_factor=4, activation='gelu', dropout=0.0):
        super().__init__()
        self.dim = dim
        self.expanded_dim = dim * expansion_factor
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.fc1 = nn.Linear(self.dim, self.expanded_dim)
        self.fc2 = nn.Linear(self.expanded_dim, self.dim)
    
    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.activation(out)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.fc2(out)
        return out + identity

class RegularBlock(nn.Module):
    def __init__(self, dim, hidden_dim=None, activation='gelu', dropout=0.0):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim if hidden_dim else dim * 2
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.fc1 = nn.Linear(self.dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.dim)
    
    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.activation(out)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.fc2(out)
        return out + identity

class MultiBranchMLP(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_blocks=4,
        dropout=0.1,
        combine_mode='concat'
    ):
        super().__init__()
        self.output_dim = output_dim
        self.combine_mode = combine_mode
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        
        self.bottleneck_branch = nn.ModuleList([
            BottleneckBlock(hidden_dim, dropout=dropout) for _ in range(num_blocks)
        ])
        self.inverted_branch = nn.ModuleList([
            InvertedBottleneckBlock(hidden_dim, dropout=dropout) for _ in range(num_blocks)
        ])
        self.regular_branch = nn.ModuleList([
            RegularBlock(hidden_dim, dropout=dropout) for _ in range(num_blocks)
        ])
        
        if combine_mode == 'concat':
            self.output_proj = nn.Linear(hidden_dim * 3, output_dim)
        else:
            self.output_proj = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = self.input_proj(x)
        
        bottleneck_out = x
        for block in self.bottleneck_branch:
            bottleneck_out = block(bottleneck_out)
        
        inverted_out = x
        for block in self.inverted_branch:
            inverted_out = block(inverted_out)
        
        regular_out = x
        for block in self.regular_branch:
            regular_out = block(regular_out)
        
        if self.combine_mode == 'concat':
            combined = torch.cat([bottleneck_out, inverted_out, regular_out], dim=1)
        else:
            combined = bottleneck_out + inverted_out + regular_out
        
        out = self.output_proj(combined)
        return out

model = MultiBranchMLP(
    input_dim=dm.input_dim,
    hidden_dim=256,
    output_dim=dm.n_classes,
    num_blocks=4,
    dropout=0.1,
    combine_mode='concat'
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


Model parameters: 8,410,884


## 3. Обучение модели

Обучаем модель с использованием PyTorch Lightning. Используем стандартные настройки оптимизатора и функцию потерь для многоклассовой классификации.


In [9]:
class TextClassificationModule(pl.LightningModule):
    def __init__(self, model, lr=1e-3, num_classes=4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_f1 = F1Score(task='multiclass', num_classes=num_classes, average='macro')
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.train_acc(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.val_acc(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

lightning_model = TextClassificationModule(model, lr=1e-3, num_classes=dm.n_classes)

trainer = Trainer(
    max_epochs=5,
    enable_checkpointing=False,
    logger=False,
    enable_progress_bar=True,
    enable_model_summary=False,
    accelerator='gpu',
    devices=1
)

trainer.fit(lightning_model, dm)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Loading from cache...
Dataset loaded: classes=4, features=20005, train=108460, val=19140


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


## 4. Измерение метрик и производительности базовой модели

Измеряем точность, F1-score, скорость инференса и потребление памяти CPU для базовой PyTorch модели. Используем валидационный датасет для оценки метрик и замеряем время выполнения на CPU.


In [10]:
def measure_model_performance(model, dataloader, device='cpu'):
    model.eval()
    model.to(device)
    
    process = psutil.Process()
    memory_before = process.memory_info().rss / 1024 / 1024
    
    y_true_all = []
    y_pred_all = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for batch in dataloader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            y_true_all.append(y.cpu().numpy())
            y_pred_all.append(preds.cpu().numpy())
    
    inference_time = time.time() - start_time
    
    memory_after = process.memory_info().rss / 1024 / 1024
    memory_used = memory_after - memory_before
    
    y_true = np.concatenate(y_true_all)
    y_pred = np.concatenate(y_pred_all)
    
    accuracy_metric = Accuracy(task='multiclass', num_classes=dm.n_classes)
    f1_metric = F1Score(task='multiclass', num_classes=dm.n_classes, average='macro')
    
    accuracy = accuracy_metric(torch.LongTensor(y_pred), torch.LongTensor(y_true)).item()
    f1 = f1_metric(torch.LongTensor(y_pred), torch.LongTensor(y_true)).item()
    
    num_samples = len(y_true)
    throughput = num_samples / inference_time
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'inference_time': inference_time,
        'throughput': throughput,
        'memory_mb': memory_used
    }

val_loader = dm.val_dataloader()
baseline_results = measure_model_performance(model, val_loader, device='cpu')

print(f"Baseline Model Results:")
print(f"  Accuracy: {baseline_results['accuracy']:.4f}")
print(f"  F1 Score: {baseline_results['f1']:.4f}")
print(f"  Inference Time: {baseline_results['inference_time']:.4f} s")
print(f"  Throughput: {baseline_results['throughput']:.2f} samples/s")
print(f"  Memory Usage: {baseline_results['memory_mb']:.2f} MB")


Baseline Model Results:
  Accuracy: 0.9031
  F1 Score: 0.9030
  Inference Time: 2.8081 s
  Throughput: 6816.09 samples/s
  Memory Usage: 0.00 MB


## 5. Конвертация в TorchScript

Конвертируем модель в TorchScript формат используя torch.jit.script. TorchScript позволяет выполнять модели независимо от Python интерпретатора и может применять различные оптимизации графа вычислений.


In [11]:
model.eval()
example_input = torch.randn(1, dm.input_dim)

scripted_model = torch.jit.script(model)

torchscript_results = measure_model_performance(scripted_model, val_loader, device='cpu')

print(f"TorchScript Model Results:")
print(f"  Accuracy: {torchscript_results['accuracy']:.4f}")
print(f"  F1 Score: {torchscript_results['f1']:.4f}")
print(f"  Inference Time: {torchscript_results['inference_time']:.4f} s")
print(f"  Throughput: {torchscript_results['throughput']:.2f} samples/s")
print(f"  Memory Usage: {torchscript_results['memory_mb']:.2f} MB")


TorchScript Model Results:
  Accuracy: 0.9031
  F1 Score: 0.9030
  Inference Time: 2.0786 s
  Throughput: 9208.04 samples/s
  Memory Usage: 0.00 MB


## 6. Конвертация в ONNX

Экспортируем модель в формат ONNX используя torch.onnx.export. ONNX является открытым стандартом для представления моделей машинного обучения и позволяет запускать модели в различных runtime окружениях, включая ONNX Runtime.


In [12]:
model.eval()
example_input = torch.randn(1, dm.input_dim)

onnx_path = "model.onnx"
torch.onnx.export(
    model,
    example_input,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11
)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

def measure_onnx_performance(ort_session, dataloader):
    process = psutil.Process()
    memory_before = process.memory_info().rss / 1024 / 1024
    
    y_true_all = []
    y_pred_all = []
    
    start_time = time.time()
    
    for batch in dataloader:
        x, y = batch
        x_np = x.numpy()
        
        outputs = ort_session.run(None, {'input': x_np})
        logits = torch.from_numpy(outputs[0])
        preds = torch.argmax(logits, dim=1)
        
        y_true_all.append(y.numpy())
        y_pred_all.append(preds.numpy())
    
    inference_time = time.time() - start_time
    
    memory_after = process.memory_info().rss / 1024 / 1024
    memory_used = memory_after - memory_before
    
    y_true = np.concatenate(y_true_all)
    y_pred = np.concatenate(y_pred_all)
    
    accuracy_metric = Accuracy(task='multiclass', num_classes=dm.n_classes)
    f1_metric = F1Score(task='multiclass', num_classes=dm.n_classes, average='macro')
    
    accuracy = accuracy_metric(torch.LongTensor(y_pred), torch.LongTensor(y_true)).item()
    f1 = f1_metric(torch.LongTensor(y_pred), torch.LongTensor(y_true)).item()
    
    num_samples = len(y_true)
    throughput = num_samples / inference_time
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'inference_time': inference_time,
        'throughput': throughput,
        'memory_mb': memory_used
    }

onnx_results = measure_onnx_performance(ort_session, val_loader)

print(f"ONNX Model Results:")
print(f"  Accuracy: {onnx_results['accuracy']:.4f}")
print(f"  F1 Score: {onnx_results['f1']:.4f}")
print(f"  Inference Time: {onnx_results['inference_time']:.4f} s")
print(f"  Throughput: {onnx_results['throughput']:.2f} samples/s")
print(f"  Memory Usage: {onnx_results['memory_mb']:.2f} MB")


ONNX Model Results:
  Accuracy: 0.9031
  F1 Score: 0.9030
  Inference Time: 2.0176 s
  Throughput: 9486.61 samples/s
  Memory Usage: 10.05 MB


## 7. Pruning модели

Применяем magnitude-based pruning к модели. Pruning удаляет наименее важные веса на основе их абсолютных значений, что позволяет уменьшить размер модели и ускорить инференс при сохранении приемлемой точности.


In [13]:
import torch.nn.utils.prune as prune

pruned_model = MultiBranchMLP(
    input_dim=dm.input_dim,
    hidden_dim=256,
    output_dim=dm.n_classes,
    num_blocks=4,
    dropout=0.1,
    combine_mode='concat'
)
pruned_model.load_state_dict(model.state_dict())

parameters_to_prune = []
for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Linear):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)

for module, name in parameters_to_prune:
    prune.remove(module, name)

pruned_results = measure_model_performance(pruned_model, val_loader, device='cpu')

print(f"Pruned Model Results:")
print(f"  Accuracy: {pruned_results['accuracy']:.4f}")
print(f"  F1 Score: {pruned_results['f1']:.4f}")
print(f"  Inference Time: {pruned_results['inference_time']:.4f} s")
print(f"  Throughput: {pruned_results['throughput']:.2f} samples/s")
print(f"  Memory Usage: {pruned_results['memory_mb']:.2f} MB")

pruned_params = sum(p.numel() for p in pruned_model.parameters())
original_params = sum(p.numel() for p in model.parameters())
print(f"  Parameters: {pruned_params:,} (original: {original_params:,}, reduction: {(1 - pruned_params/original_params)*100:.1f}%)")


Pruned Model Results:
  Accuracy: 0.9033
  F1 Score: 0.9032
  Inference Time: 2.4471 s
  Throughput: 7821.48 samples/s
  Memory Usage: 19.85 MB
  Parameters: 8,410,884 (original: 8,410,884, reduction: 0.0%)


## 8. Оптимизация графа вычислений

Применяем оптимизации графа вычислений используя torch.jit.optimize_for_inference. Эта функция выполняет различные оптимизации, такие как слияние операций, удаление неиспользуемых веток и другие преобразования для ускорения инференса.


In [14]:
model.eval()
example_input = torch.randn(1, dm.input_dim)

optimized_model = torch.jit.script(model)
optimized_model = torch.jit.optimize_for_inference(optimized_model)

optimized_results = measure_model_performance(optimized_model, val_loader, device='cpu')

print(f"Optimized Graph Model Results:")
print(f"  Accuracy: {optimized_results['accuracy']:.4f}")
print(f"  F1 Score: {optimized_results['f1']:.4f}")
print(f"  Inference Time: {optimized_results['inference_time']:.4f} s")
print(f"  Throughput: {optimized_results['throughput']:.2f} samples/s")
print(f"  Memory Usage: {optimized_results['memory_mb']:.2f} MB")


Optimized Graph Model Results:
  Accuracy: 0.9031
  F1 Score: 0.9030
  Inference Time: 2.2327 s
  Throughput: 8572.60 samples/s
  Memory Usage: 6.57 MB


## 9. Сравнение результатов

Сравниваем все методы оптимизации по метрикам точности, скорости инференса и потреблению памяти.


In [15]:
results = {
    'Baseline': baseline_results,
    'TorchScript': torchscript_results,
    'ONNX': onnx_results,
    'Pruned': pruned_results,
    'Optimized Graph': optimized_results
}

print("Comparison of all optimization methods:")
print("-" * 80)
print(f"{'Method':<20} {'Accuracy':<12} {'F1 Score':<12} {'Time (s)':<12} {'Throughput':<15} {'Memory (MB)':<15}")
print("-" * 80)
for method, res in results.items():
    print(f"{method:<20} {res['accuracy']:<12.4f} {res['f1']:<12.4f} {res['inference_time']:<12.4f} {res['throughput']:<15.2f} {res['memory_mb']:<15.2f}")
print("-" * 80)


Comparison of all optimization methods:
--------------------------------------------------------------------------------
Method               Accuracy     F1 Score     Time (s)     Throughput      Memory (MB)    
--------------------------------------------------------------------------------
Baseline             0.9031       0.9030       2.8081       6816.09         0.00           
TorchScript          0.9031       0.9030       2.0786       9208.04         0.00           
ONNX                 0.9031       0.9030       2.0176       9486.61         10.05          
Pruned               0.9033       0.9032       2.4471       7821.48         19.85          
Optimized Graph      0.9031       0.9030       2.2327       8572.60         6.57           
--------------------------------------------------------------------------------
