Skip to content

Commit

Permalink
Turn off autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
ptran1203 committed Apr 16, 2024
1 parent 2a5e7ae commit 676f795
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
11 changes: 11 additions & 0 deletions .dockerignore
@@ -0,0 +1,11 @@
__pycache__/
*.pyc
*.zip
dataset
spirit.away
.DS_Store
.todo
*.pth
*.pt
.cache
inference_images
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -8,4 +8,5 @@ spirit.away
*.pth
*.pt
.cache
inference_images
inference_images
.cog
5 changes: 3 additions & 2 deletions inference.py
Expand Up @@ -10,6 +10,7 @@
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image, is_image_file
from tqdm import tqdm
# from torch.cuda.amp import autocast


VALID_FORMATS = {
Expand Down Expand Up @@ -77,9 +78,9 @@ def transform(self, image, denorm=True):
with torch.no_grad():
image = self.preprocess_images(image)
# image = image.to(self.device)
with torch.autocast(self.device_type, enabled=self.amp):
# with autocast(self.device_type, enabled=self.amp):
# print(image.dtype, self.G)
fake = self.G(image)
fake = self.G(image)
fake = fake.detach().cpu().numpy()
# Channel last
fake = fake.transpose(0, 2, 3, 1)
Expand Down
12 changes: 8 additions & 4 deletions predict.py
@@ -1,5 +1,5 @@
from pathlib import Path
from inference import Predictor
from inference import Predictor as MyPredictor
from utils import read_image
import cv2
import tempfile
Expand All @@ -16,13 +16,17 @@ def predict(
self,
image: Path = Input(description="Image"),
model: str = Input(
description="Factor to scale image by",
description="Style",
default='Hayao:v2',
choices=['Hayao', 'Shinkai', 'Hayao:v2']
choices=[
'Hayao',
'Shinkai',
'Hayao:v2'
]
)
) -> Path:
version = model.split(":")[-1]
predictor = Predictor(model, version)
predictor = MyPredictor(model, version)
img = read_image(str(image))
anime_img = predictor.transform(resize_image(img))[0]
# out_path = Path(tempfile.mkdtemp()) / "out.png"
Expand Down

0 comments on commit 676f795

Please sign in to comment.