In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import os
import sys
import random

from skimage.io import imread, imshow
from skimage.transform import resize
from skimage import feature

from skimage.filters import sobel
from skimage.morphology import watershed

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
from torch.autograd import Variable

from torchvision import transforms as tf

import h5py

from pathlib import Path
import nibabel as nib
# from sklearn import preprocessing
from skimage import transform

# from tqdm import tqdm

# from imgaug import augmenters as iaaot as plt
import pandas as pd

In [2]:
import sys
# sys.path.insert(0, '../lib/')
# from help_functions import *

# #function to obtain data for training/testing (validation)
# from extract_patches import get_data_training
# sys.path.insert(0, '../lib/networks/')

from preprocessing import preprocessing


In [3]:
train_img, label_img = preprocessing()


train images/masks shape:
(20, 1, 565, 565)
train images range (min-max): 0.0 - 1.0
train masks are within 0-1

patches per full image: 9500

train PATCHES images/masks shape:
(190000, 1, 48, 48)
train PATCHES images range (min-max): 0.00784313725490196 - 1.0
(190000, 1, 48, 48)
48 48 1
......DONE......
old shape:  (190000, 2304)
new shape:  (190000, 2304, 2)


In [4]:
N_subimgs = 190000
indices = list(range(N_subimgs))
np.random.shuffle(indices)

val_size = 1/10
split = np.int_(np.floor(val_size * N_subimgs))

train_idxs = indices[split:]
val_idxs = indices[:split]

In [5]:
class eye_dataset(torch.utils.data.Dataset):

    def __init__(self,preprocessed_images, train=True, label=None):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.train = train
        self.images = preprocessed_images
        if self.train:
            self.label = label

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        img = np.zeros_like(image, dtype=np.float32)
        
        img += image
        label = None
        if self.train:
            label = self.label[idx]
#             msk = np.zeros((2,48,48), dtype=np.long)
#             msk[1] = label
#             msk[0] = 1-label
            
#             msk += label
            return (img, label)
        return img

eye_dataset_train = eye_dataset(train_img[train_idxs], 
                                      train=True, 
                                      label=label_img[train_idxs])

eye_dataset_val = eye_dataset(train_img[val_idxs], 
                                      train=True, 
                                      label=label_img[val_idxs])


batch_size = 64

train_loader = torch.utils.data.DataLoader(dataset=eye_dataset_train, 
                                           batch_size=batch_size, 
                                           shuffle=True)

val_loader = torch.utils.data.DataLoader(dataset=eye_dataset_val, 
                                           batch_size=batch_size, 
                                           shuffle=True)

In [6]:
def compute_qkv(chanIn,filters, qkv,layer_type='SAME'):
    '''
    Args:
        inputs: a Tensor with shape [batch,channels, h, w]
        total_key_filters: an integer
        total_value_filters: and integer
        layer_type: String, type of this layer -- SAME, DOWN, UP

    Returns:
        q: [batch, _h, _w, total_key_filters] tensor
        k: [batch, h, w, total_key_filters] tensor
		v: [batch, h, w, total_value_filters] tensor

    
    
    '''
#     print('compute_qkv .........')
    if qkv == 'q':
        # linear transformation for q , filters = key_filters
        if layer_type == 'SAME':
            qkv = nn.Conv2d(chanIn, filters, 1, 1, bias=True,padding=0)
        elif layer_type == 'DOWN':
            qkv = nn.Conv2d(chanIn, filters, 3, 2, bias=True, padding =1)
        elif layer_type == 'UP':
            qkv = nn.ConvTranspose2d(chanIn, filtersx, 3, 2, bias=True, padding=1)
    
    if qkv == 'k':
        # linear transformation for k
        qkv = nn.Conv2d(chanIn, filters, 1, 1, bias=True, padding=0)

    if qkv =='v':
        # linear transformation for k, value filtesr
        qkv = nn.Conv2d(chanIn, filters, 1, 1, bias=True, padding=0)

    return qkv


def split_heads(x, num_heads):
    """Split channels (last dimension) into multiple heads (becomes dimension 1).

    Args:
        x: a Tensor with shape [batch, h, w, channels]
        num_heads: an integer

    Returns:
        a Tensor with shape [batch, num_heads, h, w, channels / num_heads]
    """
    
    return split_last_dimension(x, num_heads).permute(0,4,1,2,3)#permute(0,3,1,2,4)

def split_last_dimension(x,n):
    """Reshape x so that the last dimension becomes two dimensions.
    The first of these two dimensions is n.
    Args:
        x: a Tensor with shape [..., m]
        n: an integer.
    Returns:
        a Tensor with shape [..., n, m/n]
    """
    

    chunk_size = int(x.shape[3]/ n)
    ret = torch.unsqueeze(x,4)
    
    ret = torch.cat(ret.split(split_size=chunk_size, dim=3),4)#.permute(0,1,2,4,3)
#     print('split', ret.shape)
#     ret.view(new_shape)
#     print('split view ', ret.shape)
    return ret

def global_attention(q, k, v, chan_z):
    """global self-attention.
    Args:
        q: a Tensor with shape [batch, heads, _d, _h, _w, channels_k]
        k: a Tensor with shape [batch, heads, d, h, w, channels_k]
        v: a Tensor with shape [batch, heads, d, h, w, channels_v]
        name: an optional string
    Returns:
        a Tensor of shape [batch, heads, _d, _h, _w, channels_v]
    """
    
    
#     print(new_shape)
    # flatten q,k,v
    q_new = flatten(q)
    k_new = flatten(k)
    v_new = flatten(v)

    # attention
    output = dot_product_attention(q_new, k_new, v_new, bias=None,
                dropout_rate=0.5)
#     print('outp ', output.shape)

    # putting the representations back in the right place
#     print('output before scatter', output.shape)
    output = scatter(output, chan_z)

    return output

def dot_product_attention(q, k, v, bias,dropout_rate=0.0):
    """Dot-product attention.
    Args:
        q: a Tensor with shape [batch, heads, length_q, channels_k]
        k: a Tensor with shape [batch, heads, length_kv, channels_k]
        v: a Tensor with shape [batch, heads, length_kv, channels_v]
        bias: bias Tensor
        dropout_rate: a floating point number
        name: an optional string
    Returns:
        A Tensor with shape [batch, heads, length_q, channels_v]
    """

    

    # [batch, num_heads, length_q, length_kv]
#     print(q.shape, k.transpose(2,3).shape)
    logits = torch.matmul(q, k.transpose(2,3))

    if bias is not None:
        logits += bias

    weights = F.softmax(logits)

    # dropping out the attention links for each of the heads
    weights = F.dropout(weights, dropout_rate)

    return torch.matmul(weights, v)


def reshape_range(tensor, i, j, shape):
    """Reshapes a tensor between dimensions i and j."""

    target_shape = tf.concat(
            [tf.shape(tensor)[:i], shape, tf.shape(tensor)[j:]],
            axis=0)

    return tf.reshape(tensor, target_shape)



def scatter(x, chn):
    """scatter x."""
    
#     print('scatter shape in ', x.shape)

    x = x.view(x.shape[0],chn,chn,x.shape[3],-1)
#     print('scatter ', x.shape)

    return x


def flatten(x):
    """flatten x."""
    # [batch, heads, h,w, ch/h]
    # [batch, heads, length, channels], length = d*h*w
    
    l  = x.shape[2] * x.shape[3]
    # [batch, heads, length, channels], length = d*h*w
    x = x.view(x.shape[0], x.shape[1], l,-1)
#     print('flatten shape', x.shape)

    return x


def combine_heads(x):
    """Inverse of split_heads_3d.
    Args:
        x: a Tensor with shape [batch, num_heads, d, h, w, channels / num_heads]
    Returns:
        a Tensor with shape [batch, d, h, w, channels]
    """
#  [0, 2, 3, 4, 1, 5]
#     print('combine heads ,', x.shape)
    return combine_last_two_dimensions(x)


def combine_last_two_dimensions(x):
    """Reshape x so that the last two dimension become one.
    Args:
        x: a Tensor with shape [..., a, b]
    Returns:
        a Tensor with shape [..., a*b]
    """

#     old_shape = x.get_shape().dims
#     a, b = old_shape[-2:]
#     new_shape = old_shape[:-2] + [a * b if a and b else None]

#     ret = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))
#     ret.set_shape(new_shape)
    
    x = x.contiguous().view(x.shape[0],x.shape[1], x.shape[2], -1)
#     print('combine last two ', x.shape)

    return x

In [9]:
""""This script defines 3D different multi-head attention layers.
(12,12,12,.5,2,False)
"""
class MultiHeadAttention_(nn.Module):
    def __init__(self,
                 chanIn,
                 output_filters,
                 total_key_filters,
                 total_value_filters,                 
                 num_heads,
                 layer_type='SAME'):
        super(MultiHeadAttention_, self).__init__()
        
        '''
        inputs: channels first as input [batch, chn, h,w]
        
        
        '''
        
        """3d Multihead scaled-dot-product attention with input/output transformations.

        Args:
            inputs: a Tensor with shape [batch, h, w, channels]
            
            total_key_filters: an integer. Note that queries have the same number 
                of channels as keys
            total_value_filters: an integer
            output_depth: an integer
            num_heads: an integer dividing total_key_filters and total_value_filters
            layer_type: a string, type of this layer -- SAME, DOWN, UP
            name: an optional string
        Returns:
            A Tensor of shape [batch, _d, _h, _w, output_filters]

        Raises:
            ValueError: if the total_key_filters or total_value_filters are not divisible
                by the number of attention heads.
        """

        if total_key_filters % num_heads != 0:
            raise ValueError("Key depth (%d) must be divisible by the number of "
                            "attention heads (%d)." % (total_key_filters, num_heads))
        if total_value_filters % num_heads != 0:
            raise ValueError("Value depth (%d) must be divisible by the number of "
                            "attention heads (%d)." % (total_value_filters, num_heads))
        if layer_type not in ['SAME', 'DOWN', 'UP']:
            raise ValueError("Layer type (%s) must be one of SAME, "
                            "DOWN, UP." % (layer_type))
            
        
        '''
        inputs: [batch, chn, h,w]
        output: [batch, chn, h,w]
        next step, permute to [batch, h,w, chn]
        '''
        self.q = compute_qkv(chanIn, total_key_filters,'q',layer_type='SAME')
        self.k = compute_qkv(chanIn, total_key_filters,'k',layer_type='SAME')
        self.v = compute_qkv(chanIn, total_value_filters,'v',layer_type='SAME')
        
        self.num_heads = num_heads
        self.total_key_filters = total_key_filters
        self.total_value_filters = total_value_filters
        
        self.conv = nn.Conv2d(total_key_filters, output_filters, 1,1,bias=True)
        
    def forward(self,x):
        
#         print('x before q shape in: ', x.shape)
        #[batch, chn, h,w]
        q = self.q(x)
#         print('q shape in: ', q.shape)
        k = self.k(x)
#         print('k shape in: ', k.shape)
        v = self.v(x)
#         print('v shape in: ', v.shape)
        
        #permute to set [batch, h,w, chn]
        q = q.permute(0,2,3,1)
#         print('q after permute ', q.shape)
        k = k.permute(0,2,3,1)
        v = v.permute(0,2,3,1)
        
#         print('q shape before split ' ,q.shape[3], q.shape)
        q = split_heads(q,self.num_heads)    
#         print('q shape after split ' ,q.shape[3], q.shape)
        k = split_heads(k,self.num_heads)
        v = split_heads(v,self.num_heads)
        #after split [batch, heads, h,w,k/h]
#         print('q split ', q.shape)

        #normalize
        key_filters_per_head = self.total_key_filters // self.num_heads
        q *= key_filters_per_head**-0.5
        
        att = global_attention(q,k,v, q.shape[2])
#         print('out of attt : ', att.shape)



        x = combine_heads(att)
        
#         print('LAST shape in, ' ,x.shape)
        x = x.permute(0,3,1,2)
#         print('LAST shape in, ' ,x.shape)
        x = self.conv(x)
#         print('out of att', x.shape)
        return x


In [10]:
def bn_relu(chanIn, chanOut, ks = 3, stride=1):
    return nn.Sequential(
        nn.BatchNorm2d(chanIn),
        nn.ReLU6(inplace=True),
        nn.Conv2d(chanIn, chanOut, ks, stride, padding=1),
        
    )

class Att_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_block_1 = bn_relu(1, 1, 3,1)
        self.input_block_2 = bn_relu(1, 1, 3,1)
        self.input_1x1 = nn.Conv2d(1,32,1,1)
        
        self.down1_block_1 = bn_relu(32, 64, 3,2)
        self.down1_block_2 = bn_relu(64,64, 3,1)
        self.down1_1x1 = nn.Conv2d(32, 64, 1,2)
        
        self.down2_block_1 = bn_relu(64, 128, 3,2)
        self.down2_block_2 = bn_relu(128, 128, 3,1)
        self.down2_1x1 = nn.Conv2d(64, 128, 1,2)
        
        self.mid = MultiHeadAttention_(128,128,32,32,4)
        self.bn = nn.BatchNorm2d(128)
        
        self.up2 = nn.Upsample(scale_factor=2)
        self.up2_1 = nn.Conv2d(128,64,3,1,1)
#         self.up2m = MultiHeadAttention_(64,64,12,12,4)
#         self.bn2 = nn.BatchNorm2d(64)
        
        self.up1 = nn.Upsample(scale_factor=2)
        self.up1_1 = nn.Conv2d(64,32,3,1,1)
#         self.up1m = MultiHeadAttention_(32,32,12,12,4)
#         self.bn3 = nn.BatchNorm2d(32)
        
        self.out = nn.Conv2d(32,2,3,1,1)
        
        
        
    def forward(self, x):
#         print('input shape ', x.shape)
        x_top = x
        x = self.input_block_1(x)
        x = self.input_block_2(x)
        x = torch.add(x_top,x) #res1
#         print('x before 1x1 shape: ',x.shape)
        x_res1 = self.input_1x1(x)
#         print('xres1  ', x_res1.shape)
        
        x_l = self.down1_1x1(x_res1)
        x = self.down1_block_1(x_res1)
        x = self.down1_block_2(x)
        x_res2 = torch.add(x_l, x) 
        x_l = self.down2_1x1(x_res2)
        x = self.down2_block_1(x_res2)
        x = self.down2_block_2(x)
        x_res3 = torch.add(x_l,x) 
        
#       res3 shape  torch.Size([32, 128, 12, 12])
        x = self.mid(x_res3)

        x = self.bn(x)
        
        x = self.up2(x)
        x = self.up2_1(x)
        
        x = torch.add(x, x_res2)
        
        x = self.up1(x)
        x = self.up1_1(x)
        x = torch.add(x, x_res1)
        x = self.out(x)
        
        x = x.view(x.shape[0], x.shape[1],-1)
        x = x.permute(0,2,1)
#         print('out : ', x.shape)
        
        
        return x
    
model = Att_Net()
model.cuda()
criterion = nn.BCEWithLogitsLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=.1,
                             weight_decay=.1)

In [11]:

for n, p in model.named_parameters():
    print(n, p.shape)


input_block_1.0.weight torch.Size([1])
input_block_1.0.bias torch.Size([1])
input_block_1.2.weight torch.Size([1, 1, 3, 3])
input_block_1.2.bias torch.Size([1])
input_block_2.0.weight torch.Size([1])
input_block_2.0.bias torch.Size([1])
input_block_2.2.weight torch.Size([1, 1, 3, 3])
input_block_2.2.bias torch.Size([1])
input_1x1.weight torch.Size([32, 1, 1, 1])
input_1x1.bias torch.Size([32])
down1_block_1.0.weight torch.Size([32])
down1_block_1.0.bias torch.Size([32])
down1_block_1.2.weight torch.Size([64, 32, 3, 3])
down1_block_1.2.bias torch.Size([64])
down1_block_2.0.weight torch.Size([64])
down1_block_2.0.bias torch.Size([64])
down1_block_2.2.weight torch.Size([64, 64, 3, 3])
down1_block_2.2.bias torch.Size([64])
down1_1x1.weight torch.Size([64, 32, 1, 1])
down1_1x1.bias torch.Size([64])
down2_block_1.0.weight torch.Size([64])
down2_block_1.0.bias torch.Size([64])
down2_block_1.2.weight torch.Size([128, 64, 3, 3])
down2_block_1.2.bias torch.Size([128])
down2_block_2.0.weight torc

In [12]:
mean_train_losses = []
mean_val_losses = []

mean_train_acc = []
mean_val_acc = []
minLoss = 99999
maxValacc = -99999
for epoch in range(100):
    print('EPOCH: ',epoch+1)
#     train_losses = []
#     val_losses = []    
    train_acc = []
    val_acc = []
    
    running_loss = 0.0
    
    model.train()
    count = 0
    for images, labels in train_loader:    
#         labels = labels.squeeze()
        images = Variable(images.cuda())
        labels = labels.type(torch.FloatTensor)
        labels = Variable(labels.cuda())
        
#         print(labels.type())
    
#         print(images.shape)
        outputs = model(images) 
#         print(outputs.shape,outputs)
#         print(labels.shape,labels)  
#         print(torch.max(labels, 1)[1])
#         print(images.shape, labels.shape)
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        
#         train_acc.append(accuracy(outputs, labels))
        
        loss.backward()
        optimizer.step()        
        
        running_loss += loss.item()
        count +=1
    
    print('Training loss:.......', running_loss/count)
#     print('Training accuracy:...', np.mean(train_acc))
    mean_train_losses.append(running_loss/count)
        
    model.eval()
    count = 0
    val_running_loss = 0.0
    for images, labels in val_loader:
#         labels = labels.squeeze()
        images = Variable(images.cuda())
        labels = labels.type(torch.FloatTensor)
        labels = Variable(labels.cuda())
        
        outputs = model(images)
        loss = criterion(outputs, labels)

#         val_acc.append(accuracy(outputs, labels))
        val_running_loss += loss.item()
        count +=1

    mean_val_loss = val_running_loss/count
    print('Validation loss:.....', mean_val_loss)
    
#     print('Training accuracy:...', np.mean(train_acc))
#     print('Validation accuracy..', np.mean(val_acc))
    
    mean_val_losses.append(mean_val_loss)
    
#     mean_train_acc.append(np.mean(train_acc))
    
#     val_acc_ = np.mean(val_acc)
#     mean_val_acc.append(val_acc_)
    
   
    if mean_val_loss < minLoss:
        torch.save(model.state_dict(), 'attention_torch_190k.pth' )
        print(f'NEW BEST Loss: {mean_val_loss} ........old best:{minLoss}')
        minLoss = mean_val_loss
        print('')
        
#     if val_acc_ > maxValacc:
#         torch.save(model.state_dict(), 'res/cam_40/best_acc_norm_10x10.pth' )
#         print(f'NEW BEST Acc: {val_acc_} ........old best:{maxValacc}')
#         maxValacc = val_acc_
    
    
    print('')

EPOCH:  1




Training loss:....... 0.2502899529233932
Validation loss:..... 0.22556899232093733
NEW BEST Loss: 0.22556899232093733 ........old best:99999


EPOCH:  2
Training loss:....... 0.2197732792464559
Validation loss:..... 0.22255984084172684
NEW BEST Loss: 0.22255984084172684 ........old best:0.22556899232093733


EPOCH:  3
Training loss:....... 0.22296283318811727
Validation loss:..... 0.26073129841374226

EPOCH:  4
Training loss:....... 0.22228942547149644
Validation loss:..... 0.22401633923904662

EPOCH:  5
Training loss:....... 0.22170294108549635
Validation loss:..... 0.2204093352110699
NEW BEST Loss: 0.2204093352110699 ........old best:0.22255984084172684


EPOCH:  6
Training loss:....... 0.22108966158424131
Validation loss:..... 0.23112892040901312

EPOCH:  7
Training loss:....... 0.22048931363077756
Validation loss:..... 0.22219244513648126

EPOCH:  8
Training loss:....... 0.22061690927377184
Validation loss:..... 0.21760794791308316
NEW BEST Loss: 0.21760794791308316 ........old bes

KeyboardInterrupt: 