In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import math

In [2]:
def ceil(v):
        if v == int(v): return int(v)
        else: return int(v+1)
in_ch = 3
out_ch = 3
ncvar=46
norm_chs = 4 * ceil(out_ch/4)
use_sam = True
use_cam = True
use_ele = True
ret_sam = False
scale = 2
cvar_ch=8
relu_a=0.01
stage_chs=(256, 128, 64)

In [4]:
# Channel attention module 
class CAM(torch.nn.Module):
    def __init__(self, in_ch, relu_a=0.01, r=2):
        super().__init__()
        self.mlp_ops = [
            torch.nn.Linear(in_ch, in_ch//r),
            torch.nn.LeakyReLU(negative_slope=relu_a), 
            torch.nn.Linear(in_ch//r, in_ch),
        ]
        
        self.amlp_layer = torch.nn.Sequential(*self.mlp_ops)
        self.out_act    = torch.nn.Sigmoid()
        
    def forward(self, x, ret_att=False):
        _max_out, _ = torch.max(x, 2, keepdim=False)
        _max_out, _ = torch.max(_max_out, -1, keepdim=False)
        
        _avg_out    = torch.mean(x, 2, keepdim=False)
        _avg_out    = torch.mean(_avg_out, -1, keepdim=False)
        
        _mlp_max    = _max_out
        for layer in self.amlp_layer:
            _mlp_max = layer(_mlp_max)
            
        _mlp_avg    = _avg_out
        for layer in self.amlp_layer:
            _mlp_avg = layer(_mlp_avg)
            
        _attention = self.out_act(_mlp_avg + _mlp_max)
        _attention = _attention.unsqueeze(-1)
        _attention = _attention.unsqueeze(-1)
   
        if ret_att:
            return _attention, _attention * x
        else:
            return _attention * x


# Spatial attention module 
class SAM(torch.nn.Module):
    def __init__(self, in_ch, relu_a=0.01):
        super().__init__()
        self.cnn_ops = [
            torch.nn.Conv2d(in_channels=2, out_channels=1, \
                            kernel_size=7, padding=3),
            torch.nn.Sigmoid(), ] # use Sigmoid to norm to [0, 1]
        
        self.attention_layer = torch.nn.Sequential(*self.cnn_ops)
        
    def forward(self, x, ret_att=False):
        _max_out, _ = torch.max(x, 1, keepdim=True)
        _avg_out    = torch.mean(x, 1, keepdim=True)
        _out = torch.cat((_max_out, _avg_out), dim=1)
        _attention = _out
        for layer in self.attention_layer:
            _attention = layer(_attention)
           
        if ret_att:
            return _attention, _attention * x
        else:
            return _attention * x


class inception_box(torch.nn.Module):
    def __init__(self, in_ch, o_ch, relu_a=0.01):
        super().__init__()
        assert o_ch % 4 == 0
        self.conv1b1_ops = [
            torch.nn.Conv2d(in_channels=in_ch, out_channels=o_ch//4, kernel_size=1, \
                            stride=1, padding=0),
            torch.nn.LeakyReLU(negative_slope=relu_a), ]
        
        self.conv3b3_ops = [
            torch.nn.Conv2d(in_channels=in_ch, out_channels=o_ch//4, kernel_size=1, \
                            stride=1, padding=0),
            torch.nn.LeakyReLU(negative_slope=relu_a), 
            torch.nn.Conv2d(in_channels=o_ch//4, out_channels=o_ch//4, kernel_size=3, \
                            stride=1, padding=1),
            torch.nn.LeakyReLU(negative_slope=relu_a), ]
        
        self.conv5b5_ops = [
            torch.nn.Conv2d(in_channels=in_ch, out_channels=o_ch//4, kernel_size=1, \
                            stride=1, padding=0),
            torch.nn.LeakyReLU(negative_slope=relu_a), 
            torch.nn.Conv2d(in_channels=o_ch//4, out_channels=o_ch//4, kernel_size=5, \
                            stride=1, padding=2),
            torch.nn.LeakyReLU(negative_slope=relu_a), ]
        
        self.maxpool_ops = [
            torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            torch.nn.Conv2d(in_channels=in_ch, out_channels=o_ch//4, kernel_size=1, \
                            stride=1, padding=0),
            torch.nn.LeakyReLU(negative_slope=relu_a), ]
        
        self.conv1b1 = torch.nn.Sequential(*self.conv1b1_ops)
        self.conv3b3 = torch.nn.Sequential(*self.conv3b3_ops)
        self.conv5b5 = torch.nn.Sequential(*self.conv5b5_ops)
        self.maxpool = torch.nn.Sequential(*self.maxpool_ops)
        
    def forward(self, x): 
        _out_conv1b1 = x
        for layer in self.conv1b1:
            _out_conv1b1 = layer(_out_conv1b1)
            
        _out_conv3b3 = x
        for layer in self.conv3b3:
            _out_conv3b3 = layer(_out_conv3b3)
            
        _out_conv5b5 = x
        for layer in self.conv5b5:
            _out_conv5b5 = layer(_out_conv5b5)
            
        _out_maxpool = x
        for layer in self.conv1b1:
            _out_maxpool = layer(_out_maxpool)
            
        return torch.cat([_out_conv1b1, _out_conv3b3, _out_conv5b5, _out_maxpool], 1)

In [5]:
in_norm_ops = [
    torch.nn.Conv2d(in_channels=out_ch, out_channels=norm_chs, \
                    kernel_size=1, stride=1, padding=0),
    torch.nn.BatchNorm2d(num_features=norm_chs),
    torch.nn.LeakyReLU(negative_slope=relu_a), ]

up1_ops = [
    torch.nn.ConvTranspose2d(in_channels=stage_chs[0]+cvar_ch*ncvar, out_channels=stage_chs[0], \
                                    kernel_size=2, stride=2, padding=0),
    torch.nn.LeakyReLU(negative_slope=0.01), ]

up2_ops = [
    torch.nn.ConvTranspose2d(in_channels=stage_chs[1], out_channels=stage_chs[1], \
                                    kernel_size=2, stride=2, padding=0),
    torch.nn.LeakyReLU(negative_slope=0.01), ]
    
if use_ele:
    ele_ops = [
        torch.nn.Conv2d(in_channels=1, out_channels=4, \
                        kernel_size=3, stride=1, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.01), 
        torch.nn.Conv2d(in_channels=4, out_channels=8, \
                        kernel_size=3, stride=1, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.01), ]

out_ops = [
    torch.nn.Conv2d(in_channels=stage_chs[2], out_channels=4, \
                    kernel_size=3, stride=1, padding=1),
    torch.nn.BatchNorm2d(num_features=4),
    torch.nn.LeakyReLU(negative_slope=0.01), 
    torch.nn.Conv2d(in_channels=4, out_channels=out_ch, # was out_channels=1,
                    kernel_size=3, stride=1, padding=1),]


cvar_inceps = [torch.nn.ModuleList([inception_box(in_ch=1, o_ch=cvar_ch), \
                        inception_box(in_ch=cvar_ch, o_ch=cvar_ch), \
                        inception_box(in_ch=cvar_ch, o_ch=cvar_ch), \
                        inception_box(in_ch=cvar_ch, o_ch=cvar_ch)]) for _ in range(ncvar)]
cvar_inceps = torch.nn.ModuleList(cvar_inceps)

ich_layers = torch.nn.Sequential(*in_norm_ops)

p1_inception1 = inception_box(in_ch = norm_chs, o_ch=stage_chs[0])
p1_inception2 = inception_box(in_ch = stage_chs[0], o_ch = stage_chs[0])
p1_inception3 = inception_box(in_ch = stage_chs[0], o_ch = stage_chs[0])
p1_inception4 = inception_box(in_ch = stage_chs[0], o_ch = stage_chs[0])
up1_layers    = torch.nn.Sequential(*up1_ops)

p2_inception1 = inception_box(in_ch = stage_chs[0], o_ch = stage_chs[1])
p2_inception2 = inception_box(in_ch = stage_chs[1], o_ch = stage_chs[1])
p2_inception3 = inception_box(in_ch = stage_chs[1], o_ch = stage_chs[1])
p2_inception4 = inception_box(in_ch = stage_chs[1], o_ch = stage_chs[1])
up2_layers    = torch.nn.Sequential(*up2_ops)

if use_cam:
    up1_cam = CAM(in_ch = stage_chs[0] + cvar_ch*ncvar)
    up2_cam = CAM(in_ch = stage_chs[1])

if use_sam:
    up1_sam = SAM(in_ch = stage_chs[0])
    up2_sam = SAM(in_ch = stage_chs[1])

if use_ele:
    ele_layers = torch.nn.Sequential(*ele_ops)
    p3_inception1 = inception_box(in_ch = 8+stage_chs[1], o_ch = stage_chs[2])
else:
    p3_inception1 = inception_box(in_ch = stage_chs[1], o_ch = stage_chs[2])
p3_inception2 = inception_box(in_ch = stage_chs[2], o_ch = stage_chs[2])
p3_inception3 = inception_box(in_ch = stage_chs[2], o_ch = stage_chs[2])
p3_inception4 = inception_box(in_ch = stage_chs[2], o_ch = stage_chs[2])
out_layers = torch.nn.Sequential(*out_ops)

In [6]:
x = torch.randn(64, 49, 32, 64)
elev = torch.randn(64, 1, 128, 256)

In [7]:
# Conditional variables as list of tensors
cvars = [x[...,i:i+1,:, :] for i in range(out_ch, x.shape[1])]
x = x[...,:out_ch,:, :]


In [8]:
assert len(cvars) == len(cvar_inceps)
cvar_outs = []
for _cf, cvar in zip(cvar_inceps, cvars):
    _tmp = cvar
    for _f in _cf:
        _tmp = _f(_tmp)
    cvar_outs.append(_tmp)

In [9]:
cvar_outs[0].shape

torch.Size([64, 8, 32, 64])

In [10]:
x.shape

torch.Size([64, 3, 32, 64])

In [11]:
out_tmp = x
for layer in ich_layers:
    out_tmp = layer(out_tmp) 
    


In [12]:
out_tmp.shape

torch.Size([64, 4, 32, 64])

In [13]:
out_tmp = p1_inception1(out_tmp)

In [14]:
out_tmp.shape

torch.Size([64, 256, 32, 64])

In [15]:
out_tmp = p1_inception2(out_tmp)
out_tmp.shape

torch.Size([64, 256, 32, 64])

In [16]:
out_tmp = p1_inception3(out_tmp)
out_tmp.shape

torch.Size([64, 256, 32, 64])

In [17]:
out_tmp = p1_inception4(out_tmp)
out_tmp.shape

torch.Size([64, 256, 32, 64])

In [18]:
if use_sam: # apply spatial attention 
    if ret_sam:
        atten1, out_tmp = up1_sam(out_tmp, ret_att=True) 
    else:
        out_tmp = up1_sam(out_tmp) 

In [19]:
out_tmp.shape

torch.Size([64, 256, 32, 64])

In [20]:
out_tmp = torch.cat([out_tmp,] + cvar_outs, 1)

In [21]:
cvar_outs[0].shape

torch.Size([64, 8, 32, 64])

In [22]:
out_tmp.shape

torch.Size([64, 624, 32, 64])

In [23]:
if use_cam:
    out_tmp = up1_cam(out_tmp) # apply channel attention 

In [24]:
out_tmp.shape

torch.Size([64, 624, 32, 64])

In [25]:
for layer in up1_layers:
    out_tmp = layer(out_tmp)  

In [26]:
out_tmp.shape

torch.Size([64, 256, 64, 128])

In [27]:
out_tmp = p2_inception1(out_tmp)
out_tmp.shape

torch.Size([64, 128, 64, 128])

In [28]:
out_tmp = p2_inception2(out_tmp)
out_tmp.shape

torch.Size([64, 128, 64, 128])

In [29]:
out_tmp = p2_inception3(out_tmp)
out_tmp.shape

torch.Size([64, 128, 64, 128])

In [30]:
out_tmp = p2_inception4(out_tmp)
out_tmp.shape

torch.Size([64, 128, 64, 128])

In [31]:
if use_cam:
    out_tmp = up2_cam(out_tmp) # apply channel attention 
    
if use_sam: # apply spatial attention 
    if ret_sam:
        atten2, out_tmp = up2_sam(out_tmp, ret_att=True) 
    else:
        out_tmp = up2_sam(out_tmp) 

for layer in up2_layers:
    out_tmp = layer(out_tmp)  

In [32]:
out_tmp.shape

torch.Size([64, 128, 128, 256])

In [33]:
if elev is not None:
    ele_tmp = elev
    for layer in ele_layers:
        ele_tmp = layer(ele_tmp)  
    out_tmp = torch.cat([out_tmp, ele_tmp], 1)

In [34]:
out_tmp.shape

torch.Size([64, 136, 128, 256])

In [35]:
out_tmp = p3_inception1(out_tmp)
out_tmp.shape

torch.Size([64, 64, 128, 256])

In [36]:
out_tmp = p3_inception2(out_tmp)
out_tmp.shape

torch.Size([64, 64, 128, 256])

In [37]:
out_tmp = p3_inception3(out_tmp)
out_tmp.shape

torch.Size([64, 64, 128, 256])

In [38]:
out_tmp = p3_inception4(out_tmp)
out_tmp.shape

torch.Size([64, 64, 128, 256])

In [39]:
for layer in out_layers:
    out_tmp = layer(out_tmp)
    print(out_tmp.shape)

torch.Size([64, 4, 128, 256])
torch.Size([64, 4, 128, 256])
torch.Size([64, 4, 128, 256])
torch.Size([64, 3, 128, 256])


In [56]:
out_tmp[:,:1,:, :].shape

torch.Size([64, 1, 128, 256])

In [122]:
tst = torch.randn(64, 3, 1024, 2048)

In [123]:
operations = [
    torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1),
    torch.nn.LeakyReLU(0.2), ]

# Intermediate layers (C128-C256-C512-C512)
out_chs = (128, 256, 512, 512)
in_chs = (64, ) + out_chs[:-1]
for ic, oc in zip(in_chs, out_chs):
    operations += [
        torch.nn.Conv2d(in_channels=ic, out_channels=oc, kernel_size=4, stride=2, padding=1),
        torch.nn.BatchNorm2d(oc),
        torch.nn.LeakyReLU(0.2), ]

# Global Average Pooling
global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

# Output layer that outputs binary logits
fc_layers = nn.Sequential(
            nn.Linear(out_chs[-1], out_chs[-1] // 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(p=0.2),
            nn.Linear(out_chs[-1] // 4, 1)
        )

layers = torch.nn.Sequential(*operations)

In [124]:
test = layers(tst)
print(test.shape)
test = global_pool(test)
print(test.shape)
test = test.view(x.size(0), -1)
print(test.shape)
test = fc_layers(test)
print(test.shape)


torch.Size([64, 512, 32, 64])
torch.Size([64, 512, 1, 1])
torch.Size([64, 512])
torch.Size([64, 1])


In [126]:
output.shape

torch.Size([64, 1, 2, 4])

In [104]:
dsc_out_size= (32, 1, 2, 4)
true_label  = torch.ones (dsc_out_size)
false_label = torch.zeros(dsc_out_size)
disc_label  = torch.cat((true_label, false_label), dim=0)

In [105]:
disc_label.shape

torch.Size([64, 1, 2, 4])

In [134]:
torch.nn.BCELoss()(torch.ones(1,1), torch.zeros(1,1))

tensor(100.)

In [98]:
output = layers(out_tmp)

In [94]:
with torch.no_grad():
    tst2 = layers(out_tmp)

In [95]:
tst2.mean().log()

tensor(-0.7115)

In [85]:
with torch.no_grad():
    advs_loss = 0 - layers(out_tmp).mean().log()

In [86]:
advs_loss

tensor(nan)

In [78]:
output.shape

torch.Size([64, 1, 4, 4])

In [61]:
output[0]

tensor([[[ 0.1540, -0.1572,  0.2223, -0.0017],
         [ 0.0184,  0.0606,  0.2462, -0.5188]]], grad_fn=<SelectBackward0>)

In [67]:
output.shape

torch.Size([64, 1, 2, 4])

In [53]:
output[0]

tensor([[[0.4419, 0.4992, 0.3913, 0.4340],
         [0.4333, 0.5574, 0.5085, 0.5064]]], grad_fn=<SelectBackward0>)

In [57]:
ceil(3)

3

In [47]:
in_norm_ops = [
    torch.nn.Conv2d(in_channels=in_ch, out_channels=norm_chs,
                    kernel_size=1, stride=1, padding=0),
    torch.nn.BatchNorm2d(num_features=norm_chs),
    torch.nn.LeakyReLU(negative_slope=relu_a), ]

in_norm_layers = torch.nn.Sequential(*in_norm_ops)

up_layers = torch.nn.ModuleList()
inception_layers = torch.nn.ModuleList()

current_in_ch = norm_chs
for i in range(len(stage_chs)):
    up_ops = [
        torch.nn.ConvTranspose2d(in_channels=current_in_ch, out_channels=stage_chs[i],
                                    kernel_size=2, stride=2, padding=0),
        torch.nn.LeakyReLU(negative_slope=relu_a), ]
    up_layers.append(torch.nn.Sequential(*up_ops))

    inception_block = torch.nn.Sequential(
        inception_box(in_ch=stage_chs[i], o_ch=stage_chs[i]),
        inception_box(in_ch=stage_chs[i], o_ch=stage_chs[i]),
        inception_box(in_ch=stage_chs[i], o_ch=stage_chs[i]),
        inception_box(in_ch=stage_chs[i], o_ch=stage_chs[i])
    )
    inception_layers.append(inception_block)
    current_in_ch = stage_chs[i]

if use_ele:
    ele_ops = [
        torch.nn.Conv2d(in_channels=1, out_channels=4,
                        kernel_size=3, stride=1, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.01),
        torch.nn.Conv2d(in_channels=4, out_channels=8,
                        kernel_size=3, stride=1, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.01), ]
    ele_layers = torch.nn.Sequential(*ele_ops)
    current_in_ch += 8

out_ops = [
    torch.nn.Conv2d(in_channels=current_in_ch, out_channels=4,
                    kernel_size=3, stride=1, padding=1),
    torch.nn.BatchNorm2d(num_features=4),
    torch.nn.LeakyReLU(negative_slope=0.01),
    torch.nn.Conv2d(in_channels=4, out_channels=in_ch,
                    kernel_size=3, stride=1, padding=1), ]
out_layers = torch.nn.Sequential(*out_ops)

cvar_inceps = [torch.nn.ModuleList([inception_box(in_ch=1, o_ch=cvar_ch),
                                            inception_box(in_ch=cvar_ch, o_ch=cvar_ch),
                                            inception_box(in_ch=cvar_ch, o_ch=cvar_ch),
                                            inception_box(in_ch=cvar_ch, o_ch=cvar_ch)]) for _ in range(ncvar)]
cvar_inceps = torch.nn.ModuleList(cvar_inceps)

if use_cam:
    cam_layers = torch.nn.ModuleList([CAM(in_ch=stage_chs[i] + cvar_ch * ncvar if i == 0 else stage_chs[i]) for i in range(len(stage_chs))])

if use_sam:
    sam_layers = torch.nn.ModuleList([SAM(in_ch=stage_chs[i]) for i in range(len(stage_chs))])


In [68]:
# Conditional variables as list of tensors
cvars = [x[..., i:i + 1, :, :] for i in range(in_ch, x.shape[1])]
x = x[..., :in_ch, :, :]


In [69]:
class discModel(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # input layer
        self.operations = [
            torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
            torch.nn.LeakyReLU(0.2), ]  
        # C128-C256-C512-C512
        out_chs = (128, 256, 512, 512, )
        in_chs  = (64, ) + out_chs[:-1]
        for ic, oc in zip(in_chs, out_chs):
            self.operations += [
                torch.nn.Conv2d(in_channels=ic, out_channels=oc, kernel_size=4, \
                                stride=2, padding=1),
                torch.nn.BatchNorm2d(oc),
                torch.nn.LeakyReLU(0.2), ]
            
        # output layers
        self.operations += [
            torch.nn.Conv2d(in_channels=out_chs[-1], out_channels=1, kernel_size=4, stride=2, padding=1),
            # torch.nn.Sigmoid(),  # comment this line for BCEWithLogitsLoss
            ]
        
        self.layers = torch.nn.Sequential(*self.operations)
        
    def forward(self, x):
        return self.layers(x)

In [73]:
disc = discModel(in_channels=3)

In [75]:
x.shape

torch.Size([64, 3, 32, 64])

In [49]:
assert len(cvars) == len(cvar_inceps)
cvar_outs = []
for _cf, cvar in zip(cvar_inceps, cvars):
    _tmp = cvar
    for _f in _cf:
        _tmp = _f(_tmp)
    cvar_outs.append(_tmp)


In [51]:
cvar_outs[0].shape

torch.Size([64, 8, 32, 64])

In [50]:
x.shape

torch.Size([64, 3, 32, 64])

In [52]:
out_tmp = x
for layer in in_norm_layers:
    out_tmp = layer(out_tmp)
    print(out_tmp.shape)


torch.Size([64, 4, 32, 64])
torch.Size([64, 4, 32, 64])
torch.Size([64, 4, 32, 64])


In [53]:
out_tmp.shape

torch.Size([64, 4, 32, 64])

In [54]:
len(up_layers)

3

In [None]:
for i in range(len(up_layers)):
    out_tmp = inception_layers[i](out_tmp)
    print(out_tmp.shape)
    if use_sam:
        if ret_sam:
            atten, out_tmp = sam_layers[i](out_tmp, ret_att=True)
        else:
            out_tmp = sam_layers[i](out_tmp)

    out_tmp = torch.cat([out_tmp] + cvar_outs, 1)  # concat cvars

    if use_cam:
        out_tmp = cam_layers[i](out_tmp)  # apply channel attention

    out_tmp = up_layers[i](out_tmp)
    print(out_tmp.shape)


In [None]:
if elev is not None:
    ele_tmp = elev
    for layer in ele_layers:
        ele_tmp = layer(ele_tmp)
    out_tmp = torch.cat([out_tmp, ele_tmp], 1)

out_tmp = out_layers(out_tmp)
