In [None]:
%load_ext autoreload
%autoreload 2

import markdown
import torch
import madpack

from IPython.core.display import display, HTML
from madpack.doc import html_docstring, html_docstring_class
from madpack import log
log.level = 'info'

In [None]:
# Convert the project README from markdown to HTML and show it inline.
# The context manager closes the file handle right after reading instead of
# leaving it open until garbage collection.
with open('../Readme.md', 'r') as readme_file:
    readme_html = markdown.markdown(readme_file.read())
display(HTML(readme_html))

# Datasets

## DatasetBase class

All datasets in madpack inherit from `DatasetBase`. Let's take a look at its methods.

In [None]:
from madpack.datasets import DatasetBase

# Build the HTML documentation for DatasetBase, then show it inline.
base_doc_html = html_docstring_class(DatasetBase)
display(HTML(base_doc_html))

In [None]:
# Show the documentation of every name listed in `madpack.datasets.__all__`
# except the base class; the methods below are passed as `exclude`.
hidden_methods = ('install', '__getitem__', 'check_data_integrity')
for dataset_type in madpack.datasets.__all__:
    if dataset_type == 'DatasetBase':
        continue
    dataset_cls = getattr(madpack.datasets, dataset_type)
    display(HTML(html_docstring_class(dataset_cls, exclude=hidden_methods)))

## Dataset usage

In [None]:
from madpack.datasets import SquareCountDummy
# `dset` is reused by the following cells.
# NOTE(review): 'train' is presumably the split name — confirm against SquareCountDummy.
dset = SquareCountDummy('train')

Datasets can define repository files that are copied to a local folder when required. By default, the dataset path is `~/datasets` and the repository path is `~/dataset_repositories`; both can be symlinks. They can be overridden in the config file `~/.config/madpack.yaml`.

In [None]:
# Display the `repository_files` attribute of the dataset as the cell output.
dset.repository_files

The tuple or list `sample_ids` assigns a unique identifier to each sample of the dataset. In this case it is a parameterization of the squares.

In [None]:
# Display the first entry of `sample_ids` as the cell output.
dset.sample_ids[0]

When a dataset implements the `sample_ids` attribute (a tuple), it is used to define the dataset's length. Furthermore, the `resize` option becomes available. Using `split_overlap`, the splits can be checked for overlaps.

In [None]:
from madpack.utils import split_overlap

# Print the dataset length three times: untouched, after `resize(100)`,
# and after `resize(None)`.
size_full = len(dset)
print('original dataset size', size_full)

dset.resize(100)
size_small = len(dset)
print('reduced to', size_small)

dset.resize(None)
size_restored = len(dset)
print('and back at', size_restored)

print("Now let's check for overlap...!")
# Kept as the last expression so its return value is shown as the cell output.
split_overlap(SquareCountDummy)

### Visualizing Datasets

In [None]:
# Import only the name this cell uses instead of star-importing the whole
# module into the notebook namespace.
from madpack.interactive import plot_data

# Visualize samples of `dset`; `shuffle=True` is forwarded to `plot_data`.
plot_data(dset, shuffle=True)

# Transforms

It is recommended to rely on `torchvision`'s transforms as much as possible.

In [None]:
# Show the HTML documentation of every name listed in
# `madpack.transforms.__all__`.
for transform_name in madpack.transforms.__all__:
    transform_obj = getattr(madpack.transforms, transform_name)
    display(HTML(html_docstring_class(transform_obj)))

# Models

In [None]:
# NOTE(review): RN18Narrow is imported but not used in this cell.
from madpack.models import RN18Dense, RN18Narrow
import torch

inp = torch.zeros(1, 3, 128, 128)

# The model call returns a pair, unpacked here as (out, activations).
model = RN18Dense()
out, activations = model(inp)

# Dimensions from index 2 onward must match between output and input.
assert out.shape[2:] == inp.shape[2:]

In [None]:
import torch
from madpack.models import NarrowRN18Dense, RN18Dense
from madpack.utils import count_parameters

m = NarrowRN18Dense(channels=(32, 32, 16, 16), decoder_shape='xs')

# Parameter counts for the whole model, its `resnet` and its `decoder2`.
print(*(count_parameters(part) for part in (m, m.resnet, m.decoder2)))

inp = torch.zeros(1, 3, 48, 48)
out = m(inp)

# The first output must match the input from dimension 2 onward and have
# size 10 along dimension 1.
first_out = out[0]
assert first_out.shape[2:] == inp.shape[2:]
assert first_out.shape[1] == 10

In [None]:
from madpack.models import NarrowRN50Dense

# Uses `inp` defined by an earlier cell.
m = NarrowRN50Dense(channels=(16, 16, 16, 16))
out = m(inp)

# The first output must match the input from dimension 2 onward and have
# size 10 along dimension 1.
dense_out = out[0]
assert dense_out.shape[2:] == inp.shape[2:]
assert dense_out.shape[1] == 10

# Transform usage

In [None]:
# Explicit import of `plt` (the only name this section previously picked up
# through `from madpack.interactive import *`).
from matplotlib import pyplot as plt
from madpack.transforms import imread

img = imread('sample_image1.jpg')

# Move dimension 0 to the end before plotting (imshow expects channels last).
plt.imshow(img.permute(1,2,0))

Scale proportionally to size defined by box (here 200 by 200).

In [None]:
from madpack.transforms import resize

# Slice rows 100:200 and columns 50:110 across all of dimension 0.
img_part = img[:, 100:200, 50:110]
print(img_part.shape)

# Resize the patch with a (200, 200) target and max_bound=True.
out = resize(img_part, (200, 200), max_bound=True)

plt.imshow(out.permute(1,2,0))

### random crop

In [None]:
from madpack.transforms.spatial import random_crop

crop_size = (150, 150)

# Print the shape before and after cropping along dimensions 1 and 2.
print('before', img.shape)
img_crop = random_crop(img, crop_size, spatial_dims=(1,2))
print('after', img_crop.shape)

plt.imshow(img_crop.permute(1,2,0))

### pad

In [None]:
from madpack.transforms.spatial import pad_to_square

# The same image region in two layouts: channel dimension first and last.
region_chw = img[:, :100, :200]
region_hwc = region_chw.permute(1,2,0)

img_square = pad_to_square(region_chw, channel_dim=0)
img_square2 = pad_to_square(region_hwc, channel_dim=2)

# Both layouts must end up with the same non-channel dimensions.
assert img_square.shape[1:] == img_square2.shape[0:2]
plt.imshow(img_square2)

### random crop containing a selected area

Generate crops (light blue) that encompass the yellow square.


In [None]:
from madpack.transforms.spatial import random_crop_special_by_map
from matplotlib import pyplot as plt
import torch

images = []
_, ax = plt.subplots(1, 5, figsize=(15, 3))

# Boolean 200 x 200 map with one randomly placed 30 x 30 square set to True.
a = torch.zeros(200, 200).bool()
pos = torch.randint(0, 200-30, (2,))
top, left = pos.tolist()
a[top: top + 30, left: left + 30] = 1

# One sampled crop per subplot.
for panel in ax:
    off_y, off_x, size, iters = random_crop_special_by_map(a, (80,40))

    # The crop rectangle is set to 1; the square from `a` adds 3 on top.
    b = torch.zeros_like(a, dtype=torch.uint8)
    b[off_y: off_y + size[0], off_x: off_x + size[1]] = 1
    b += a.byte() * 3
    images.append(b)

    panel.imshow(b)
    panel.axis('off')

### adaptive size

In [None]:
# Uses `a`, `images`, `plt`, `torch` and `random_crop_special_by_map` from the
# preceding cell; here the call gets an (80, 20) size and adapt_size=True.
_, ax = plt.subplots(1, 5, figsize=(15, 3))

for panel in ax:
    off_y, off_x, size, iters = random_crop_special_by_map(a, (80,20), adapt_size=True)

    # The crop rectangle is set to 1; the square from `a` adds 3 on top.
    b = torch.zeros_like(a, dtype=torch.uint8)
    b[off_y: off_y + size[0], off_x: off_x + size[1]] = 1
    b += a.byte() * 3
    images.append(b)

    panel.imshow(b)
    panel.axis('off')