Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/int8/training/vgg16/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ Use the exporter script to create a torchscript module you can compile with Torc

### For PTQ
```
python3 export_ckpt.py <path-to-checkpoint>
python3 export.py --ckpt <path-to-checkpoint> --ir torchscript --output vgg.ts
```

The checkpoint file should be from the original training and not quatization aware fine tuning. THe script should produce a file called `trained_vgg16.jit.pt`
* `--ckpt` : The checkpoint file should be from the original training and not quatization aware fine tuning.
* `--ir` : Options include `torchscript` or `exported_program`. The saved module type is determined by this `ir` flag.
* `--output` : Output file name

### For QAT
To export a QAT model, you can run
Expand Down
119 changes: 119 additions & 0 deletions examples/int8/training/vgg16/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from vgg16 import vgg16


def test(model, dataloader, crit):
"""
Run the model on a dataset and measure accuracy/loss
"""
total = 0
correct = 0
loss = 0.0
class_probs = []
class_preds = []

with torch.no_grad():
for data, labels in dataloader:
data, labels = data.cuda(), labels.cuda(non_blocking=True)
out = model(data)
loss += crit(out, labels)
preds = torch.max(out, 1)[1]
class_probs.append([F.softmax(i, dim=0) for i in out])
class_preds.append(preds)
total += labels.size(0)
correct += (preds == labels).sum().item()

return loss / total, correct / total


def evaluate(model):
"""
Evaluate pre-trained model on CIFAR 10 dataset
"""
testing_dataset = datasets.CIFAR10(
root="./data",
train=False,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
),
)

testing_dataloader = torch.utils.data.DataLoader(
testing_dataset, batch_size=32, shuffle=False, num_workers=2
)

crit = torch.nn.CrossEntropyLoss()

test_loss, test_acc = test(model, testing_dataloader, crit)
print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))


def export_model(args):
"""
Evaluate and export the model to Torchscript or exported program
"""
# Define the VGG model
# model = vgg16(num_classes=10, init_weights=False)
model = models.vgg16(weights=None).eval().cuda()
# Load the checkpoint
ckpt = torch.load(args.ckpt)
weights = ckpt["model_state_dict"]
model.load_state_dict(weights)
# Setting eval here causes both JIT and TRT accuracy to tank in LibTorch will follow up with PyTorch Team
# model.eval()
random_inputs = [torch.rand([32, 3, 32, 32]).to("cuda")]
if args.ir == "torchscript":
jit_model = torch.jit.trace(model, random_inputs)
jit_model.eval()
# Evaluating JIT model
evaluate(jit_model)
torch.jit.save(jit_model, args.output)
elif args.ir == "exported_program":
dim_x = torch.export.Dim("dim_x", min=1, max=32)
exp_program = torch.export.export(
model, tuple(random_inputs), dynamic_shapes={"x": {0: dim_x}}
)
evaluate(exp_program)
torch.export.save(exp_program, args.output)
else:
raise ValueError(
f"Invalid IR {args.ir} provided to export the VGG model. Select among torchscript | exported_program"
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export trained VGG")
parser.add_argument("--ckpt", type=str, help="Path to saved checkpoint")
parser.add_argument(
"--ir",
type=str,
default="torchscript",
help="IR to determine the output type of exported graph",
)
parser.add_argument(
"--output", type=str, default="vgg.ts", help="Path to saved checkpoint"
)
parser.add_argument(
"--qat",
action="store_true",
help="Perform QAT using pytorch-quantization toolkit",
)
args = parser.parse_args()
export_model(args)
86 changes: 0 additions & 86 deletions examples/int8/training/vgg16/export_ckpt.py

This file was deleted.

12 changes: 5 additions & 7 deletions examples/int8/training/vgg16/finetune_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vgg16 import vgg16

PARSER = argparse.ArgumentParser(
Expand Down Expand Up @@ -231,7 +229,7 @@ def main():

quant_modules.initialize()

model = vgg16(num_classes=num_classes, init_weights=False)
model = vgg16(num_classes=num_classes, init_weights=False).cuda()
model = model.cuda()

crit = nn.CrossEntropyLoss()
Expand Down
10 changes: 4 additions & 6 deletions examples/int8/training/vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from vgg16 import vgg16

PARSER = argparse.ArgumentParser(
Expand Down Expand Up @@ -125,8 +124,7 @@ def main():

num_classes = len(classes)

model = vgg16(num_classes=num_classes, init_weights=False)
model = model.cuda()
model = vgg16(num_classes=num_classes, init_weights=False).cuda()

data = iter(training_dataloader)
images, _ = next(data)
Expand Down Expand Up @@ -233,7 +231,7 @@ def test(model, dataloader, crit, epoch):
test_preds = torch.cat(class_preds)
for i in range(len(classes)):
add_pr_curve_tensorboard(i, test_probs, test_preds, epoch)
# print(loss, total, correct, total)

return loss / total, correct / total


Expand Down
103 changes: 0 additions & 103 deletions examples/int8/training/vgg16/test_qat.py

This file was deleted.

Loading