generated from fastai/nbdev_template
/
cyclegan.py
94 lines (82 loc) · 3.71 KB
/
cyclegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04_inference.cyclegan.ipynb (unless otherwise specified).
__all__ = ['FolderDataset', 'load_dataset', 'get_preds_cyclegan', 'export_generator']
# Cell
from ..models.cyclegan import *
from ..train.cyclegan import *
from ..data.unpaired import *
from fastai.vision.all import *
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import glob
from fastprogress.fastprogress import progress_bar
import os
import PIL
# Cell
class FolderDataset(Dataset):
    """
    A PyTorch Dataset class that can be created from a folder `path` of images, for the sole purpose of inference. Optional `transforms`
    can be provided.
    Attributes: \n
    `self.files`: A sorted list of the file paths in the folder (subdirectories are skipped). \n
    `self.totensor`: `torchvision.transforms.ToTensor` transform. \n
    `self.transform`: The transforms passed in as `transforms` to the constructor (identity if none are given).
    """
    def __init__(self, path, transforms=None):
        """Constructor for this PyTorch Dataset, need to pass the `path` (a `str` or `os.PathLike` folder)."""
        # os.path.join works for both str and pathlib.Path (unlike `path + '/*'`), and
        # glob.escape keeps folder names containing '[', '*' or '?' from being treated as patterns.
        pattern = os.path.join(glob.escape(os.fspath(path)), '*')
        # Sorting gives a deterministic order (glob's order is filesystem-dependent);
        # directories are skipped because they cannot be opened as images.
        self.files = sorted(f for f in glob.glob(pattern) if os.path.isfile(f))
        self.totensor = torchvision.transforms.ToTensor()
        # An empty Compose acts as the identity and, unlike a lambda, can be pickled when a
        # DataLoader starts worker processes with the 'spawn' start method.
        self.transform = torchvision.transforms.Compose(transforms if transforms else [])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return `(file_path, image_tensor)` for item `idx` (indices wrap around the number of files)."""
        fn = self.files[idx % len(self.files)]
        # The context manager closes the underlying file handle; ToTensor reads the pixel
        # data inside the block, so the tensor is complete before the file is closed.
        with PIL.Image.open(fn) as image:
            tensor = self.totensor(image)
        return fn, self.transform(tensor)
# Cell
def load_dataset(test_path,bs=4,num_workers=4):
    "A helper function for getting a DataLoader for images in the folder `test_path`, with batch size `bs`, and number of workers `num_workers`"
    # Normalize with mean = std = 0.5 maps ToTensor's [0, 1] output to [-1, 1]
    # (predictions are mapped back with `/2 + 0.5` before saving).
    dataset = FolderDataset(
        path=test_path,
        transforms=[torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    # Inference gains nothing from shuffling; a fixed order keeps runs reproducible.
    # Each item carries its own filename, so predictions stay matched to their inputs.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=bs,
        num_workers=num_workers,
        shuffle=False
    )
    return loader
# Cell
def get_preds_cyclegan(learn,test_path,pred_path,convert_to='B',bs=4,num_workers=4,device='cuda',suffix='tif'):
    """
    A prediction function that takes the Learner object `learn` with the trained model, the `test_path` folder with the images to perform
    batch inference on, and the output folder `pred_path` where the predictions will be saved. The function will convert images to the domain
    specified by `convert_to` (default is 'B'). The other arguments are the batch size `bs` (default=4), `num_workers` (default=4), the `device`
    to run inference on (default='cuda') and suffix of the prediction images `suffix` (default='tif').
    Each input `<name>.<ext>` is saved as `<name>_fake<convert_to>.<suffix>` in `pred_path`, which is created if missing.
    Raises `ValueError` if `convert_to` is not 'A' or 'B', and `FileNotFoundError` if `test_path` does not exist.
    """
    # Validate before touching the filesystem or the model.
    if convert_to not in ('A', 'B'):
        raise ValueError("convert_to must be 'A' or 'B' (generator that converts either from A to B or B to A)")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"test_path does not exist: {test_path}")
    os.makedirs(pred_path, exist_ok=True)
    test_dl = load_dataset(test_path,bs,num_workers)
    model = (learn.model.G_B if convert_to == 'B' else learn.model.G_A).to(device)
    # Run in eval mode so dropout/normalization layers behave deterministically,
    # then restore whatever mode the generator was in so training can resume unaffected.
    was_training = model.training
    model.eval()
    try:
        # no_grad: inference needs no autograd graph, which saves memory per batch.
        with torch.no_grad():
            for fns, ims in progress_bar(test_dl, total=len(test_dl)):
                # Undo the [-1, 1] input normalization so the saved images are in [0, 1].
                preds = model(ims.to(device)) / 2 + 0.5
                for fn, pred in zip(fns, preds):
                    # splitext keeps dotted names intact ("a.b.png" -> "a.b"), avoiding output collisions.
                    stem = os.path.splitext(os.path.basename(fn))[0]
                    new_fn = os.path.join(pred_path, f'{stem}_fake{convert_to}.{suffix}')
                    torchvision.utils.save_image(pred, new_fn)
    finally:
        model.train(was_training)
# Cell
def export_generator(learn, generator_name='generator',path=Path('.'),convert_to='B'):
    """
    Save the `state_dict` of one CycleGAN generator from the Learner `learn` to `path/<generator_name>.pth`.
    `convert_to='B'` exports `learn.model.G_B`, `convert_to='A'` exports `learn.model.G_A`; any other value raises `ValueError`.
    `path` may be a `str` or a `Path`; the directory is created (with parents) if it does not exist.
    Returns the `Path` of the saved file.
    """
    if convert_to=='B':
        model = learn.model.G_B
    elif convert_to=='A':
        model = learn.model.G_A
    else:
        raise ValueError("convert_to must be 'A' or 'B' (generator that converts either from A to B or B to A)")
    # Wrap in Path so `/` works when callers pass a plain string.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    out_fn = path/(generator_name+'.pth')
    torch.save(model.state_dict(),out_fn)
    return out_fn