Skip to content

Commit

Permalink
ptq test error correction (#2860)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Jun 7, 2024
1 parent 2645e21 commit 4f82ea7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions tests/py/ts/ptq/test_ptq_dataloader_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch_tensorrt as torchtrt
import torchvision
import torchvision.transforms as transforms
import torch_tensorrt.ts.ptq as PTQ
from torch.nn import functional as F
from torch_tensorrt.ts.logging import *

Expand Down Expand Up @@ -76,11 +77,11 @@ def test_compile_script(self):
self.testing_dataloader = torch.utils.data.DataLoader(
self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
)
self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
self.calibrator = PTQ.DataLoaderCalibrator(
self.testing_dataloader,
cache_file="./calibration.cache",
use_cache=False,
algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device("cuda:0"),
)

Expand Down
5 changes: 3 additions & 2 deletions tests/py/ts/ptq/test_ptq_to_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch_tensorrt as torchtrt
import torchvision
import torchvision.transforms as transforms
import torch_tensorrt.ts.ptq as PTQ
from torch.nn import functional as F
from torch_tensorrt.ts.logging import *

Expand Down Expand Up @@ -76,11 +77,11 @@ def test_compile_script(self):
self.testing_dataloader = torch.utils.data.DataLoader(
self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
)
self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
self.calibrator = PTQ.DataLoaderCalibrator(
self.testing_dataloader,
cache_file="./calibration.cache",
use_cache=False,
algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device("cuda:0"),
)

Expand Down

0 comments on commit 4f82ea7

Please sign in to comment.