In [37]:
from pathlib import Path
from PIL import Image
import numpy as np
import dill as pickle
from src.utils import get_borders
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from typing import Union  
from torchvision import transforms
from functools import partial

In [20]:
class ConvAuto(nn.Conv2d):
    """`nn.Conv2d` whose padding is derived from the kernel size.

    Takes exactly the same arguments as `nn.Conv2d`. Any `padding` supplied by
    the caller is overridden with `kernel_size // 2` per spatial dimension. With
    stride 1 and dilation 1 this keeps the spatial size unchanged for odd kernel
    sizes; even kernel sizes grow each spatial dimension by one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dynamically add padding based on the kernel size.
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        # nn.Conv2d caches the padding in this private tuple and uses it (instead
        # of `self.padding`) whenever `padding_mode != 'zeros'`. Keep it in sync so
        # that e.g. `padding_mode='reflect'` also receives the automatic padding.
        self._reversed_padding_repeated_twice = tuple(
            pad for pad in reversed(self.padding) for _ in range(2)
        )

        
class MultiConv(nn.Module):
    """Parallel convolutions with different kernel sizes, fused by a 1x1 convolution.

    Every branch is a `ConvAuto` applied to the same input. The branch outputs
    are concatenated along the channel dimension (dim=1) and mapped back to
    `out_channels` channels by a 1x1 convolution.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by each branch and by the
            final 1x1 convolution.
        kernel_sizes: One kernel size per parallel branch.
        bias: Whether the branch convolutions use a bias term (the 1x1
            convolution always does).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 32, kernel_sizes: tuple = (7,), bias: bool = True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.conv_dict = {}  # not used by this class; kept so existing attribute access keeps working

        # The branches run side by side and are never chained, so they live in a
        # ModuleList rather than a Sequential. Parameter names in the state dict
        # (convs.0.weight, ...) are the same for both containers.
        self.convs = nn.ModuleList([
            ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
            for kernel_size in kernel_sizes
        ])

        # 1x1 convolution to map the concatenated branch outputs back to out_channels
        self.conv_1x1 = nn.Conv2d(in_channels=(out_channels * len(kernel_sizes)), out_channels=out_channels, kernel_size=1, bias=True)

    def forward(self, x):
        """Apply all branches to `x`, concatenate along channels and fuse with the 1x1 convolution."""
        branch_outputs = [conv(x) for conv in self.convs]
        return self.conv_1x1(torch.cat(branch_outputs, dim=1))
        

class CNNBaseMulti(nn.Module):
    def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_sizes: Union[tuple, int] = (7,), batch_norm: bool = False, kernel_size_out: int = 7):
        """Simple CNN built from `MultiConv` blocks.

        Args:
            n_in_channels: Number of channels of the network input.
            n_hidden_layers: Number of `MultiConv` (+ optional batch norm) + ReLU stages.
            n_kernels: Number of output channels of every hidden stage.
            kernel_sizes: Kernel size(s) of the parallel branches inside every
                `MultiConv`; a single int, a tuple, or a list.
            batch_norm: If True, insert `nn.BatchNorm2d` after every `MultiConv`.
            kernel_size_out: Kernel size of the final convolution, which is padded
                with `kernel_size_out // 2` and outputs a single channel.
        """
        # The original call named a non-existent class `CNNBase`, which raised a
        # NameError on instantiation; the zero-argument form cannot go stale.
        super().__init__()

        # Normalise `kernel_sizes` to a tuple (int -> 1-tuple, list -> tuple).
        if isinstance(kernel_sizes, int):
            kernel_sizes = (kernel_sizes,)
        else:
            kernel_sizes = tuple(kernel_sizes)

        cnn = []
        for _ in range(n_hidden_layers):
            cnn.append(MultiConv(in_channels=n_in_channels, out_channels=n_kernels, kernel_sizes=kernel_sizes, bias=True))
            if batch_norm:
                cnn.append(nn.BatchNorm2d(n_kernels))
            cnn.append(torch.nn.ReLU())
            # Every stage after the first consumes the previous stage's n_kernels channels.
            n_in_channels = n_kernels
        self.hidden_layers = torch.nn.Sequential(*cnn)

        self.output_layer = torch.nn.Conv2d(in_channels=n_in_channels, out_channels=1,
                                            kernel_size=kernel_size_out, bias=True, padding=kernel_size_out // 2)

    def forward(self, x):
        """Apply CNN to input `x` of shape (N, n_channels, X, Y), where N=n_samples and X, Y are spatial dimensions"""
        cnn_out = self.hidden_layers(x)  # apply hidden layers (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y)
        pred = self.output_layer(cnn_out)  # apply output layer (N, n_kernels, X, Y) -> (N, 1, X, Y)
        return pred

In [5]:
# Root directory that is searched recursively for *.jpg images.
input_dir = "data"

In [16]:
# Load every *.jpg below `input_dir` as a grayscale tensor resized to 90x90.
imgs = []
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
    transforms.Resize((90, 90)),
])
# rglob yields paths in filesystem-dependent order; sorting makes `imgs[0]`
# (used as the working example) reproducible across machines and runs.
for f in sorted(Path(input_dir).rglob("*.jpg")):
    # Close the file handle as soon as the image has been converted.
    with Image.open(f) as image:
        imgs.append(tf(image))

In [17]:
# Use the first loaded image as the working example.
img = imgs[0]

In [19]:
# Show the example image tensor.
img

tensor([[[0.2279, 0.2109, 0.1660,  ..., 0.2700, 0.2864, 0.1379],
         [0.2030, 0.1680, 0.1920,  ..., 0.0698, 0.2075, 0.4795],
         [0.2294, 0.1424, 0.1086,  ..., 0.3743, 0.3833, 0.4518],
         ...,
         [0.3333, 0.2233, 0.3995,  ..., 0.4441, 0.2631, 0.1029],
         [0.3925, 0.3777, 0.2783,  ..., 0.4684, 0.2546, 0.1677],
         [0.3905, 0.3388, 0.3239,  ..., 0.4263, 0.2075, 0.1802]]])

In [21]:
# Three MultiConv blocks: the first maps 1 -> 32 channels, the other two keep 32 channels.
conv1 = MultiConv(in_channels=1, out_channels=32, kernel_sizes=(7, 9, 11), bias=True)
conv2 = MultiConv(in_channels=32, out_channels=32, kernel_sizes=(5, 7, 9), bias=True)
conv3 = MultiConv(in_channels=32, out_channels=32, kernel_sizes=(5, 7, 9), bias=True)

In [28]:
# Add a leading dimension of size 1 so the image can be passed to the conv layers as a batch of one.
x = img.unsqueeze(0)

In [26]:
# Channel counts used by the branch prototypes (`conv_blocks`, `filter_concat`).
in_channels = 1
out_channels = 32

In [None]:
class Inception(nn.Module):
    """Parallel convolutions with different kernel sizes, fused by a 1x1 convolution.

    Every branch is a `ConvAuto` applied to the same input. The branch outputs
    are concatenated along the channel dimension (dim=1) and mapped back to
    `out_channels` channels by a 1x1 convolution. Apart from the default
    `kernel_sizes`, this is currently the same architecture as `MultiConv`.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by each branch and by the
            final 1x1 convolution.
        kernel_sizes: One kernel size per parallel branch.
        bias: Whether the branch convolutions use a bias term (the 1x1
            convolution always does).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 32, kernel_sizes: tuple = (5,7), bias: bool = True):
        # The original `super(MultiConv, self)` raised a TypeError because
        # Inception is not a subclass of MultiConv.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.conv_dict = {}  # not used by this class; kept so existing attribute access keeps working

        # The branches run side by side and are never chained, so they live in a
        # ModuleList rather than a Sequential. Parameter names in the state dict
        # (convs.0.weight, ...) are the same for both containers.
        self.convs = nn.ModuleList([
            ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
            for kernel_size in kernel_sizes
        ])

        # 1x1 convolution to map the concatenated branch outputs back to out_channels
        self.conv_1x1 = nn.Conv2d(in_channels=(out_channels * len(kernel_sizes)), out_channels=out_channels, kernel_size=1, bias=True)

    def forward(self, x):
        """Apply all branches to `x`, concatenate along channels and fuse with the 1x1 convolution."""
        branch_outputs = [conv(x) for conv in self.convs]
        return self.conv_1x1(torch.cat(branch_outputs, dim=1))

In [74]:
# Four parallel branches. nn.Sequential is used here only as an indexable
# container: the branches are applied one by one to the same input and their
# outputs concatenated, never chained through the container itself.
conv_blocks = nn.Sequential(
    # `conv_layer` was not defined anywhere in the notebook; the recorded repr of
    # `conv_blocks` shows this branch as a 1x1 ConvAuto(1, 32), so spell it out.
    ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True), #conv_1x1
    nn.Sequential(
        ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True),
        ConvAuto(in_channels=out_channels, out_channels=out_channels, kernel_size=3, bias=True)
    ), #conv_3x3
    nn.Sequential(
        ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True),
        ConvAuto(in_channels=out_channels, out_channels=out_channels, kernel_size=5, bias=True)
    ), #conv_5x5
    nn.Sequential(
        nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        ConvAuto(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
    ) #max_pool_3x3
)

# 1x1 convolution that maps the four concatenated branch outputs back to out_channels.
filter_concat = ConvAuto(in_channels=out_channels * 4, out_channels=out_channels, kernel_size=1, bias=True)

In [75]:
# Run every branch on the same input and concatenate the results along the channel dimension (dim=1).
x_ = torch.cat([c(x) for c in conv_blocks], dim=1)

In [76]:
# Shape of the concatenated branch outputs (four branches with out_channels channels each).
x_.shape

torch.Size([1, 128, 90, 90])

In [78]:
# Shape after the 1x1 `filter_concat` convolution has reduced the channels back to out_channels.
filter_concat(x_).shape

torch.Size([1, 32, 90, 90])

In [65]:
# `max_pool_3x3` only existed as leftover kernel state; the max-pool branch is the
# last entry (index 3) of `conv_blocks`, so reference it there.
conv_blocks[3](x).shape

torch.Size([1, 32, 90, 90])

In [59]:
# Stand-alone 3x3 max pool with stride 1 and padding 1 (same settings as the pooling branch).
m = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

In [60]:
# Output shape of the pool applied to the example input.
m(x).shape

torch.Size([1, 1, 90, 90])

In [67]:
# Show the structure of the branch container.
conv_blocks

Sequential(
  (0): ConvAuto(1, 32, kernel_size=(1, 1), stride=(1, 1))
  (1): Sequential(
    (0): ConvAuto(1, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): ConvAuto(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (2): Sequential(
    (0): ConvAuto(1, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): ConvAuto(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (3): Sequential(
    (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
    (1): ConvAuto(1, 32, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [None]:
    # Augmentation pipeline: tensor conversion, random resized crop to 90x90,
    # then horizontal and vertical flips, each with probability 0.33.
    tf_aug = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=(90, 90)),
        transforms.RandomHorizontalFlip(p=0.33),
        transforms.RandomVerticalFlip(p=0.33),
        # Normalisation is disabled for now:
        # transforms.Normalize((mean,), (std,))
    ])
    

In [96]:
# Candidate transforms: identity, three rotations (90/180/270 degrees) and a flip
# along each of dims 2 and 3. The dims assume input with at least four dimensions.
tf_list = [
    transforms.Lambda(lambda x: x),  # a missing comma here made the cell a SyntaxError
    transforms.Lambda(lambda x: torch.rot90(x, 1, dims=(2,3))),
    transforms.Lambda(lambda x: torch.rot90(x, 2, dims=(2,3))),
    transforms.Lambda(lambda x: torch.rot90(x, 3, dims=(2,3))),
    transforms.Lambda(lambda x: torch.flip(x, dims=(2,))),
    transforms.Lambda(lambda x: torch.flip(x, dims=(3,)))
]

In [87]:
# Horizontal flip with probability 1, i.e. applied on every call.
fl = transforms.RandomHorizontalFlip(1)

In [100]:
# Draw one random index into `tf_list`.
idx = torch.randint(0, len(tf_list), size=(1,)).item()

In [112]:
def random_rotation_flip(x):
    """Pick one of six rotations/flips at random and return it as a callable.

    The candidates act on the last two dimensions of a tensor: identity,
    rotation by 90, 180 or 270 degrees, and a flip along either of the two
    dimensions. Because the dimensions are counted from the end, both
    (C, H, W) and (N, C, H, W) inputs work. The 90 and 270 degree rotations
    swap the last two dimensions, so the shape is only preserved for square
    inputs.

    Args:
        x: Not inspected; kept so existing calls `random_rotation_flip(x)` keep
            working.

    Returns:
        A function that maps a tensor to its transformed version. The choice is
        made once per call of `random_rotation_flip`, using torch's global RNG.
    """
    candidates = [
        lambda t: t,
        lambda t: torch.rot90(t, 1, dims=(-2, -1)),
        lambda t: torch.rot90(t, 2, dims=(-2, -1)),
        lambda t: torch.rot90(t, 3, dims=(-2, -1)),
        lambda t: torch.flip(t, dims=(-2,)),
        lambda t: torch.flip(t, dims=(-1,)),
    ]
    idx = torch.randint(0, len(candidates), size=(1,)).item()
    return candidates[idx]

In [115]:
# Sample a fresh rotation/flip on every call. Calling `random_rotation_flip(x)`
# here directly would pick the operation once, so `tf` would always apply the
# same transform.
tf = transforms.Lambda(lambda t: random_rotation_flip(t)(t))

In [116]:
# Apply the transform to the example input.
tf(x)

tensor([[[[0.3905, 0.3925, 0.3333,  ..., 0.2294, 0.2030, 0.2279],
          [0.3388, 0.3777, 0.2233,  ..., 0.1424, 0.1680, 0.2109],
          [0.3239, 0.2783, 0.3995,  ..., 0.1086, 0.1920, 0.1660],
          ...,
          [0.4263, 0.4684, 0.4441,  ..., 0.3743, 0.0698, 0.2700],
          [0.2075, 0.2546, 0.2631,  ..., 0.3833, 0.2075, 0.2864],
          [0.1802, 0.1677, 0.1029,  ..., 0.4518, 0.4795, 0.1379]]]])

In [108]:
# Show the untransformed input for comparison.
x

tensor([[[[0.2279, 0.2109, 0.1660,  ..., 0.2700, 0.2864, 0.1379],
          [0.2030, 0.1680, 0.1920,  ..., 0.0698, 0.2075, 0.4795],
          [0.2294, 0.1424, 0.1086,  ..., 0.3743, 0.3833, 0.4518],
          ...,
          [0.3333, 0.2233, 0.3995,  ..., 0.4441, 0.2631, 0.1029],
          [0.3925, 0.3777, 0.2783,  ..., 0.4684, 0.2546, 0.1677],
          [0.3905, 0.3388, 0.3239,  ..., 0.4263, 0.2075, 0.1802]]]])