-
Notifications
You must be signed in to change notification settings - Fork 5
/
itemtensor.py
149 lines (130 loc) · 6.1 KB
/
itemtensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/vision.augment.itemtensor.ipynb.
# %% ../../../nbs/vision.augment.itemtensor.ipynb 1
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai
# %% ../../../nbs/vision.augment.itemtensor.ipynb 3
from __future__ import annotations
import operator
from PIL import Image
from torchvision.transforms.functional import pad as tvpad
from torchvision.transforms.functional import _interpolation_modes_from_int
from torch.nn.functional import interpolate
from fastai.vision.augment import RandomCrop, CropPad, Resize, PadMode, RandomResizedCrop, RatioResize, ResizeMethod, _pad_modes, _get_sz
from fastai.vision.core import TensorImage, TensorMask
from ...imports import *
# %% auto 0
__all__ = ['encodes']
# %% ../../../nbs/vision.augment.itemtensor.ipynb 7
def _resize(
    x:TensorImage|TensorMask,
    size:int|tuple[int,...],
    shape:list[int]|tuple[int,...],
    interpolation:str
):
    "View `x` as `shape` (callers in this file pass `[1,c,h,w]`) and resample it to `size` with `interpolate`."
    # `interpolate` only accepts an explicit `align_corners` for interpolating modes:
    # pin it to False for bilinear/bicubic and leave it unset (None) for every other mode.
    corner_flag = None
    if interpolation in ("bilinear", "bicubic"):
        corner_flag = False
    batched = x.view(shape)
    return interpolate(batched, size=size, mode=interpolation, align_corners=corner_flag)
# %% ../../../nbs/vision.augment.itemtensor.ipynb 8
@patch
def resize(x:TensorImage|TensorMask,
    size:int|tuple[int,int],
    interpolation:Enum
):
    "Resample a `(c,h,w)` or `(h,w)` tensor to `size` (an int or `(h,w)`), returning a tensor of the same rank."
    # `interpolation` is expected to be an enum whose `.value` is an `interpolate` mode string
    # (e.g. torchvision's `InterpolationMode`); a plain mode string is accepted as well.
    mode = getattr(interpolation, 'value', interpolation)
    # Normalise `size` to an explicit (sh, sw) pair: an int or 1-element sequence is duplicated.
    if isinstance(size, int):
        size = (size, size)
    size = tuple(size)
    if len(size) == 1:
        size = size * 2
    if len(size) != 2:
        raise ValueError(f"`size` must be an int or a 1- or 2-element sequence, got {size}")
    sh, sw = size
    if x.ndim == 3:
        c, h, w = x.shape
    elif x.ndim == 2:
        # 2D input (e.g. a mask) is treated as a single channel
        c = 1
        h, w = x.shape
    else:
        raise ValueError(f"Expected a 2D (h,w) or 3D (c,h,w) tensor, got shape {tuple(x.shape)}")
    # `interpolate` needs a batch and channel dim, so resample a [1,c,h,w] view
    out = _resize(x, size=(sh, sw), shape=[1, c, h, w], interpolation=mode)
    # Drop the temporary dims so the result has the input's rank
    return out.view(c, sh, sw) if x.ndim == 3 else out.view(sh, sw)
# %% ../../../nbs/vision.augment.itemtensor.ipynb 9
@patch
def _do_crop_pad(x:TensorImage|TensorMask,
    sz:fastuple, # Crop/pad size of input, in `(w,h)` order
    tl:fastuple, # Top-left coordinate of the crop/pad in `(w,h)` order; negative values pad. Must already be resolved (not `None`)
    orig_sz:fastuple, # Original size of input, in `(w,h)` order
    pad_mode:PadMode=PadMode.Zeros, # Fastai padding mode
    resize_mode:int=Image.BILINEAR, # Pillow `Image` resize mode (forced to `Image.NEAREST` for masks)
    resize_to:tuple[int,int]|None=None # Post crop/pad resize of input, in `(w,h)` order
):
    "Crop and/or pad `x` to the `sz` window starting at `tl`, then optionally resize it to `resize_to`."
    # PyTorch and PIL axis are opposite, need to reverse PIL axis input for crop and resize
    if any(tl.ge(0)) or any(tl.add(sz).le(orig_sz)):
        # At least one dim is inside the image, so needs to be cropped
        # The window is clamped to the image: start at max(tl, 0), end at min(tl+sz, orig_sz)
        c = tl.max(0)
        left, top, right, bottom = *c, *tl.add(sz).min(orig_sz)
        # Slice the last two dims (h, then w) so both (c,h,w) and (h,w) inputs work
        x = x[..., top:bottom, left:right]
    if any(tl.lt(0)) or any(tl.add(sz).ge(orig_sz)):
        # At least one dim is outside the image, so needs to be padded
        # p: padding before (left, top) where tl is negative
        p = (-tl).max(0)
        # f: padding after (right, bottom) where tl+sz runs past orig_sz
        f = (sz-orig_sz).add(tl).max(0)
        if len(x.shape)==2:
            # 2D inputs get a temporary leading channel dim around `tvpad`, removed afterwards.
            # NOTE(review): confirm whether the torchvision version in use still requires this.
            x = x.view(1, x.shape[0], x.shape[1])
            # `tvpad` takes a 4-element padding as (left, top, right, bottom)
            x = tvpad(x, (*p, *f), padding_mode=_pad_modes[pad_mode])
            x = x.view(x.shape[1], x.shape[2])
        else:
            x = tvpad(x, (*p, *f), padding_mode=_pad_modes[pad_mode])
    if resize_to is not None:
        # Masks always use nearest-neighbour, presumably so label values are not blended
        resize_mode = Image.NEAREST if isinstance(x,TensorMask) else resize_mode
        # `resize_to` is (w,h); the tensor `resize` expects (h,w), hence the reversal
        x = x.resize([*resize_to][::-1], _interpolation_modes_from_int(resize_mode))
    return x
# %% ../../../nbs/vision.augment.itemtensor.ipynb 10
@patch
def crop_pad(x:TensorImage|TensorMask,
    sz:int|tuple[int,int], # Crop/pad size of input; a single int is used for both dims
    tl:tuple[int,int]|None=None, # Top-left coordinate of the crop/pad; `None` centers the window
    orig_sz:tuple[int,int]|None=None, # Original size of input; defaults to the size of `x`
    pad_mode:PadMode=PadMode.Zeros, # Fastai padding mode
    resize_mode:int=Image.BILINEAR, # Pillow `Image` resize mode
    resize_to:tuple[int,int]|None=None # Post crop/pad resize of input
):
    "Resolve sizes and the top-left corner for tensor `x`, then delegate the crop/pad/resize to `_do_crop_pad`."
    if isinstance(sz, int): sz = (sz, sz)
    cur_sz = _get_sz(x)
    orig_sz = fastuple(cur_sz if orig_sz is None else orig_sz)
    # Default corner centers the `sz` window on the current size of `x`
    tl = fastuple((cur_sz - sz)//2 if tl is None else tl)
    sz = fastuple(sz)
    opts = dict(orig_sz=orig_sz, pad_mode=pad_mode, resize_mode=resize_mode, resize_to=resize_to)
    if not isinstance(x, TensorMask):
        return x._do_crop_pad(sz, tl, **opts)
    # Masks are processed as float and cast back to long afterwards
    # (presumably because padding/interpolation expect floating tensors)
    return x.float()._do_crop_pad(sz, tl, **opts).long()
# %% ../../../nbs/vision.augment.itemtensor.ipynb 12
@RandomCrop
def encodes(self, x:TensorImage|TensorMask):
    "Crop/pad tensor `x` to `self.size` at corner `self.tl`, relative to `self.orig_sz`."
    # NOTE(review): `self.tl` / `self.orig_sz` are assumed to be prepared by `RandomCrop` before each call — confirm
    return x.crop_pad(sz=self.size, tl=self.tl, orig_sz=self.orig_sz)
# %% ../../../nbs/vision.augment.itemtensor.ipynb 18
@CropPad
def encodes(self, x:TensorImage|TensorMask):
    "Center crop and/or pad tensor `x` to `self.size` using `self.pad_mode`."
    src_sz = _get_sz(x)
    # Corner that centers the target window; negative components mean padding is needed
    center_tl = (src_sz - self.size)//2
    return x.crop_pad(self.size, center_tl, orig_sz=src_sz, pad_mode=self.pad_mode)
# %% ../../../nbs/vision.augment.itemtensor.ipynb 22
@Resize
def encodes(self, x:TensorImage|TensorMask):
    "Resize tensor `x` to `self.size`, squishing it or first cropping/padding it to the target aspect ratio."
    orig_sz = _get_sz(x)
    interp = self.mode_mask if isinstance(x,TensorMask) else self.mode
    if self.method==ResizeMethod.Squish:
        # Keep the whole input; the final resize distorts the aspect ratio
        window, corner = orig_sz, fastuple(0,0)
    else:
        w, h = orig_sz
        tw, th = self.size[0], self.size[1]
        ratio_w, ratio_h = w/tw, h/th
        # Pad picks the larger ratio (window covers the whole input); otherwise the smaller one (crop)
        pick = operator.gt if self.method==ResizeMethod.Pad else operator.lt
        m = ratio_w if pick(ratio_w, ratio_h) else ratio_h
        window = (int(m*tw), int(m*th))
        # `self.pcts` positions the window inside the slack along each axis
        corner = fastuple(int(self.pcts[0]*(w-window[0])), int(self.pcts[1]*(h-window[1])))
    return x.crop_pad(window, corner, orig_sz=orig_sz, pad_mode=self.pad_mode,
        resize_mode=interp, resize_to=self.size)
# %% ../../../nbs/vision.augment.itemtensor.ipynb 28
@RandomResizedCrop
def encodes(self, x:TensorImage|TensorMask):
    "Crop `self.cp_size` at `self.tl` and resize it to `self.final_size`; center crop to `self.size` when they differ."
    interp = self.mode_mask if isinstance(x,TensorMask) else self.mode
    cropped = x.crop_pad(self.cp_size, self.tl, orig_sz=self.orig_sz,
        resize_mode=interp, resize_to=self.final_size)
    # Validation set: one final center crop
    return cropped.crop_pad(self.size) if self.final_size != self.size else cropped
# %% ../../../nbs/vision.augment.itemtensor.ipynb 34
@RatioResize
def encodes(self, x:TensorImage|TensorMask):
    "Resize tensor `x` so its longest side equals `self.max_sz`, scaling the other side proportionally."
    w, h = _get_sz(x)
    if w < h:
        # Portrait: height becomes max_sz
        nw, nh = w*self.max_sz/h, self.max_sz
    else:
        # Landscape or square: width becomes max_sz
        nw, nh = self.max_sz, h*self.max_sz/w
    # `Resize` takes (h, w); fractional sides are truncated with int()
    resizer = Resize(size=(int(nh),int(nw)), resamples=self.resamples)
    return resizer(x)