In [None]:
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F
import torch._utils
from torch.nn import init
import numpy as np
import torch.optim
import os
from sys import path
path.append(r"C:\Users\Gautam.Mathur")
from carafe import CARAFE

# From: https://arxiv.org/pdf/2107.00782v2.pdf
class PSA_p(nn.Module):
    """Polarized self-attention with the two branches applied in parallel.

    A channel-only gate and a spatial-only gate are both computed from the
    same input; each gate rescales the input and the two results are summed.

    Parameters
    ----------
    channel: number of channels of the input feature map (must be even,
        because the value/query projections use ``channel // 2`` channels).
    """

    def __init__(self, channel=512):
        super().__init__()
        half = channel // 2

        # Channel-only branch: value / query projections, output projection
        # back to `channel`, and a LayerNorm over the channel axis.
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()

        # Spatial-only branch: value / query projections and global pooling.
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return ``spatial_gate * x + channel_gate * x`` (same shape as ``x``)."""
        batch, chans, height, width = x.size()
        half = chans // 2

        # ---- channel-only self-attention ----
        ch_value = self.ch_wv(x).reshape(batch, half, -1)          # bs, c//2, h*w
        ch_query = self.ch_wq(x).reshape(batch, -1, 1)             # bs, h*w, 1
        ch_query = self.softmax_channel(ch_query)                  # softmax over h*w
        ch_context = torch.matmul(ch_value, ch_query).unsqueeze(-1)  # bs, c//2, 1, 1
        ch_logits = self.ch_wz(ch_context).reshape(batch, chans, 1).permute(0, 2, 1)  # bs, 1, c
        ch_gate = self.sigmoid(self.ln(ch_logits))
        ch_gate = ch_gate.permute(0, 2, 1).reshape(batch, chans, 1, 1)  # bs, c, 1, 1
        channel_out = ch_gate * x

        # ---- spatial-only self-attention ----
        sp_value = self.sp_wv(x).reshape(batch, half, -1)          # bs, c//2, h*w
        sp_query = self.agp(self.sp_wq(x))                         # bs, c//2, 1, 1
        sp_query = sp_query.permute(0, 2, 3, 1).reshape(batch, 1, half)  # bs, 1, c//2
        sp_query = self.softmax_spatial(sp_query)                  # softmax over c//2
        sp_logits = torch.matmul(sp_query, sp_value)               # bs, 1, h*w
        sp_gate = self.sigmoid(sp_logits.reshape(batch, 1, height, width))  # bs, 1, h, w
        spatial_out = sp_gate * x

        return spatial_out + channel_out

class PFC(nn.Module):
    """Depthwise conv + residual, followed by a pointwise conv.

    Parameters
    ----------
    channels: number of input *and* output channels. The depthwise
        convolution uses ``groups=channels`` and its output is added to the
        input, so the input must have exactly ``channels`` channels.
    kernel_size: size of the (odd) depthwise kernel; padding is
        ``kernel_size // 2`` so the spatial size is preserved.
    """

    def __init__(self, channels, kernel_size=7):
        super(PFC, self).__init__()
        # The input width used to be hard-coded to 32, which only worked when
        # `channels == 32`; it now follows `channels` (same layers for 32).
        self.depthwise = nn.Sequential(
                    nn.Conv2d(channels, channels, kernel_size,
                              groups=channels, padding=kernel_size // 2, bias=False),
                    nn.BatchNorm2d(channels))
        self.pointwise = nn.Sequential(
                    nn.Conv2d(channels, channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(channels),
                    nn.Mish(inplace=True))
        self.act = nn.Mish()

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        residual = x
        x = self.depthwise(x)
        # Out-of-place add: avoids mutating the BatchNorm output in place.
        x = x + residual
        x = self.act(x)
        x = self.pointwise(x)
        return x

def _make_pair(value):
    if isinstance(value, int):
        value = (value,) * 2
    return value


def conv_layer(in_channels,
               out_channels,
               kernel_size,
               bias=True):
    """
    Build an ``nn.Conv2d`` whose padding is derived from the kernel size.

    The padding on each axis is ``int((k - 1) / 2)``, which keeps the spatial
    size unchanged for odd kernels at the default stride of 1.

    Parameters
    ----------
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size: int or (height, width) pair.
    bias: whether the convolution has a bias term.
    """
    kernel = _make_pair(kernel_size)
    pad_h = int((kernel[0] - 1) / 2)
    pad_w = int((kernel[1] - 1) / 2)
    return nn.Conv2d(in_channels, out_channels, kernel,
                     padding=(pad_h, pad_w), bias=bias)


def sequential(*args):
    """
    Chain modules into one ``nn.Sequential``, flattening nested Sequentials.

    Parameters
    ----------
    args: modules, in the order they should run.

    Returns
    -------
    The single argument itself when exactly one is passed; otherwise an
    ``nn.Sequential`` holding the children of every ``nn.Sequential``
    argument and every other ``nn.Module`` argument. Arguments that are not
    modules are skipped.

    Raises
    ------
    NotImplementedError: when the only argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return only

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)


class Conv(nn.Module):
    """Conv2d optionally followed by BatchNorm + activation (``BNPReLU``).

    Parameters
    ----------
    nIn, nOut: input / output channel counts.
    kSize, stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
    bn_acti: when true, the convolution output goes through ``BNPReLU(nOut)``.
    """

    def __init__(self, nIn, nOut, kSize, stride, padding, dilation=(1, 1), groups=1, bn_acti=False, bias=False):
        super().__init__()

        self.bn_acti = bn_acti
        self.conv = nn.Conv2d(nIn, nOut,
                              kernel_size=kSize,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups,
                              bias=bias)

        # The norm/activation module only exists when it was requested.
        if bn_acti:
            self.bn_relu = BNPReLU(nOut)

    def forward(self, input):
        features = self.conv(input)
        return self.bn_relu(features) if self.bn_acti else features
    
    
class BNPReLU(nn.Module):
    """BatchNorm2d followed by a Mish activation.

    Parameters
    ----------
    nIn: number of channels of the input feature map.
    """

    def __init__(self, nIn):
        super().__init__()
        self.bn = nn.BatchNorm2d(nIn)
        # nn.Mish takes a single `inplace` flag, not a channel count. The
        # previous `nn.Mish(nIn)` passed the channel count as that flag, which
        # only happened to act as `inplace=True`; spell the flag out instead.
        self.acti = nn.Mish(inplace=True)

    def forward(self, input):
        """Return ``mish(batch_norm(input))``; ``input`` itself is not modified."""
        output = self.bn(input)
        # In-place on the fresh BatchNorm output, not on the caller's tensor.
        output = self.acti(output)

        return output


class self_attn(nn.Module):
    """Sigmoid-gated self-attention over a flattened axis, with a learned residual scale.

    Parameters
    ----------
    in_channels: number of input channels (query/key use ``in_channels // 8``).
    mode: any string; the attention length is ``height`` if it contains
        ``'h'``, times ``width`` if it contains ``'w'`` (``'hw'`` = all pixels).

    NOTE(review): the tensors are flattened with ``view(batch, -1, axis)``
    directly on the (B, C, H, W) layout. For mode ``'h'`` that groups runs of
    ``height`` consecutive elements of memory, which run along the width
    axis — confirm this is the intended axial behaviour.
    """

    def __init__(self, in_channels, mode='hw'):
        super(self_attn, self).__init__()

        self.mode = mode

        reduced = in_channels // 8
        self.query_conv = Conv(in_channels, reduced, kSize=(1, 1), stride=1, padding=0)
        self.key_conv = Conv(in_channels, reduced, kSize=(1, 1), stride=1, padding=0)
        self.value_conv = Conv(in_channels, in_channels, kSize=(1, 1), stride=1, padding=0)

        # gamma starts at zero, so the module is the identity at initialisation.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``gamma * attention_output + x`` (same shape as ``x``)."""
        batch_size, channel, height, width = x.size()

        # Length of the axis that attention is computed over.
        axis = (height if 'h' in self.mode else 1) * (width if 'w' in self.mode else 1)
        flat_shape = (batch_size, -1, axis)

        query = self.query_conv(x).view(*flat_shape).permute(0, 2, 1)   # b, axis, *
        key = self.key_conv(x).view(*flat_shape)                        # b, *, axis

        # Sigmoid (not softmax) over the axis-by-axis affinity matrix.
        attention = self.sigmoid(torch.bmm(query, key))                 # b, axis, axis
        value = self.value_conv(x).view(*flat_shape)                    # b, *, axis

        attended = torch.bmm(value, attention.permute(0, 2, 1))
        attended = attended.view(batch_size, channel, height, width)

        return self.gamma * attended + x

# Axial attention from: https://arxiv.org/pdf/1912.12180.pdf
class AA_kernel(nn.Module):
    """Two convolutions followed by height-mode then width-mode ``self_attn``.

    Parameters
    ----------
    in_channel: channels of the input feature map.
    out_channel: channels produced by the first 1x1 convolution and kept by
        every later layer.
    """

    def __init__(self, in_channel, out_channel):
        super(AA_kernel, self).__init__()
        self.conv0 = Conv(in_channel, out_channel, kSize=1, stride=1, padding=0)
        self.conv1 = Conv(out_channel, out_channel, kSize=(3, 3), stride=1, padding=1)
        self.Hattn = self_attn(out_channel, mode='h')
        self.Wattn = self_attn(out_channel, mode='w')

    def forward(self, x):
        features = self.conv1(self.conv0(x))
        # Attention is applied sequentially: 'h' mode first, then 'w' mode.
        return self.Wattn(self.Hattn(features))

# Pixelshuffle block from: https://arxiv.org/pdf/1609.05158v2.pdf
def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Build a bias-free conv followed by ``nn.PixelShuffle``.

    The convolution expands to ``out_channels * upscale_factor ** 2`` channels,
    which the pixel shuffle rearranges into ``out_channels`` channels at
    ``upscale_factor`` times the spatial resolution.
    """
    expanded_channels = out_channels * (upscale_factor ** 2)
    return sequential(
        conv_layer(in_channels, expanded_channels, kernel_size, bias=False),
        nn.PixelShuffle(upscale_factor),
    )

    
class ResidualPSA(nn.Module):
    """
    Residual block: conv-BN-Mish, PSA attention, conv-BN, skip add, Mish.

    Parameters
    ----------
    in_channels: channels of the input feature map.
    out_channels: channels of the output; ``None`` means ``in_channels``.
        The skip connection adds the raw input to the block output, so the
        two must be equal for the addition to work.
    """

    def __init__(self,
                 in_channels,
                 out_channels):
        super(ResidualPSA, self).__init__()

        if out_channels is None:
            out_channels = in_channels

        # Sub-modules are registered in the same order as before so that
        # state_dict / parameter ordering is unchanged.
        self.c1_r = conv_layer(in_channels, out_channels, 3, bias=False)
        self.act = nn.Mish()
        self.c2_r = conv_layer(out_channels, out_channels, 3, bias=False)
        self.norm_layer = nn.BatchNorm2d(out_channels)
        self.norm_layer1 = nn.BatchNorm2d(out_channels)
        self.psa = PSA_p(out_channels)

    def forward(self, x):
        # First stage: conv -> BN -> Mish -> polarized self-attention.
        hidden = self.act(self.norm_layer(self.c1_r(x)))
        hidden = self.psa(hidden)

        # Second stage: conv -> BN, then the residual connection.
        hidden = self.norm_layer1(self.c2_r(hidden))
        return self.act(hidden + x)


def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

    Either attribute is skipped when it is missing or ``None``.
    """
    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        nn.init.constant_(weight_param, val)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialise ``module.weight`` and set ``module.bias`` to a constant.

    Parameters
    ----------
    module: object with a ``weight`` tensor (and optionally a ``bias``).
    a, mode, nonlinearity: forwarded to ``nn.init.kaiming_*_``.
    bias: constant written to ``module.bias`` when it exists and is not None.
    distribution: ``'uniform'`` or ``'normal'``; anything else fails the assert.
    """
    assert distribution in ['uniform', 'normal']

    initializer = (nn.init.kaiming_uniform_ if distribution == 'uniform'
                   else nn.init.kaiming_normal_)
    initializer(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
        

    
class CFPModule(nn.Module):
    """Four parallel dilated-convolution branches fused hierarchically, plus a residual.

    The input is normalised, reduced to ``nIn // 4`` channels, and fed to four
    branches of three grouped convolutions each. The branches use dilations
    ``1``, ``int(d/4 + 1)``, ``int(d/2 + 1)`` and ``d + 1``. Each branch
    concatenates its three stage outputs (``nIn//16 + nIn//16 + nIn//8``
    channels); the branch results are cumulatively summed, concatenated back
    to ``nIn`` channels, projected by a 1x1 conv and added to the input.

    Parameters
    ----------
    nIn: number of input (and output) channels. The channel arithmetic and
        ``groups = nIn // 16`` assume ``nIn`` is a multiple of 16.
    d: base dilation rate.
    KSize: kernel size (int or pair) of the entry convolution.
    dkSize: kernel size (int) of the dilated branch convolutions.

    Padding is derived from the kernel size and dilation
    (``dilation * (k - 1) // 2``) so that odd kernels keep the spatial size
    and the final residual add is valid. For the default 3x3 kernels this is
    exactly the padding that was previously hard-coded.
    """

    def __init__(self, nIn, d=1, KSize=3, dkSize=3):
        super().__init__()

        quarter = nIn // 4
        eighth = nIn // 8
        sixteenth = nIn // 16

        def same_pad(kernel, dilation=1):
            # Size-preserving padding for an odd kernel at stride 1.
            return dilation * (kernel - 1) // 2

        def branch_conv(c_in, c_out, dilation):
            # One grouped, dilated conv + BN + activation of a branch.
            pad = same_pad(dkSize, dilation)
            return Conv(c_in, c_out, (dkSize, dkSize), 1, padding=(pad, pad),
                        dilation=(dilation, dilation), groups=sixteenth, bn_acti=True)

        if isinstance(KSize, int):
            entry_pad = same_pad(KSize)
        else:
            entry_pad = tuple(same_pad(k) for k in KSize)

        # Per-branch dilation rates (branch 1 is undilated, branch 4 the widest).
        dil_1 = 1
        dil_2 = int(d / 4 + 1)
        dil_3 = int(d / 2 + 1)
        dil_4 = d + 1

        # NOTE: modules are created in the original order (branch 4 first) so
        # parameter ordering and seeded initialisation are unchanged.
        self.bn_relu_1 = BNPReLU(nIn)
        self.bn_relu_2 = BNPReLU(nIn)
        # Despite the name this is a KSize x KSize convolution (3x3 by default).
        self.conv1x1_1 = Conv(nIn, quarter, KSize, 1, padding=entry_pad, bn_acti=True)

        self.dconv_4_1 = branch_conv(quarter, sixteenth, dil_4)
        self.dconv_4_2 = branch_conv(sixteenth, sixteenth, dil_4)
        self.dconv_4_3 = branch_conv(sixteenth, eighth, dil_4)

        self.dconv_1_1 = branch_conv(quarter, sixteenth, dil_1)
        self.dconv_1_2 = branch_conv(sixteenth, sixteenth, dil_1)
        self.dconv_1_3 = branch_conv(sixteenth, eighth, dil_1)

        self.dconv_2_1 = branch_conv(quarter, sixteenth, dil_2)
        self.dconv_2_2 = branch_conv(sixteenth, sixteenth, dil_2)
        self.dconv_2_3 = branch_conv(sixteenth, eighth, dil_2)

        self.dconv_3_1 = branch_conv(quarter, sixteenth, dil_3)
        self.dconv_3_2 = branch_conv(sixteenth, sixteenth, dil_3)
        self.dconv_3_3 = branch_conv(sixteenth, eighth, dil_3)

        self.conv1x1 = Conv(nIn, nIn, 1, 1, padding=0, bn_acti=False)

    def forward(self, input):
        """Return a tensor with the same shape as ``input``."""
        inp = self.bn_relu_1(input)
        inp = self.conv1x1_1(inp)

        o1_1 = self.dconv_1_1(inp)
        o1_2 = self.dconv_1_2(o1_1)
        o1_3 = self.dconv_1_3(o1_2)

        o2_1 = self.dconv_2_1(inp)
        o2_2 = self.dconv_2_2(o2_1)
        o2_3 = self.dconv_2_3(o2_2)

        o3_1 = self.dconv_3_1(inp)
        o3_2 = self.dconv_3_2(o3_1)
        o3_3 = self.dconv_3_3(o3_2)

        o4_1 = self.dconv_4_1(inp)
        o4_2 = self.dconv_4_2(o4_1)
        o4_3 = self.dconv_4_3(o4_2)

        # Each branch contributes nIn // 4 channels.
        output_1 = torch.cat([o1_1, o1_2, o1_3], 1)
        output_2 = torch.cat([o2_1, o2_2, o2_3], 1)
        output_3 = torch.cat([o3_1, o3_2, o3_3], 1)
        output_4 = torch.cat([o4_1, o4_2, o4_3], 1)

        # Hierarchical fusion: every branch also sees the sum of the
        # smaller-dilation branches before concatenation.
        ad1 = output_1
        ad2 = ad1 + output_2
        ad3 = ad2 + output_3
        ad4 = ad3 + output_4
        output = torch.cat([ad1, ad2, ad3, ad4], 1)
        output = self.bn_relu_2(output)
        output = self.conv1x1(output)

        return output + input

In [None]:
    
class HRSupResNet(nn.Module):
    """
    High resolution super resolution network. This approach basically flips the
    architecture of HRNet: https://arxiv.org/pdf/1904.04514.pdf
    And uses PSA attention in the encoder with CARAFE to upsample
    
    Carafe: https://openaccess.thecvf.com/content_ICCV_2019/papers/Wang_CARAFE_Content-Aware_ReAssembly_of_FEatures_ICCV_2019_paper.pdf

    Parameters
    ----------
    in_channels: number of bands of the input image (input of the first 1x1 conv).
    out_channels, feature_channels, upscale: accepted but not referenced
        anywhere in this class; layer widths are hard-coded to 32 / 16 / 1.
    tenmfeats: channel width of the 1x (input-resolution) residual blocks.
        The first conv always emits 32 channels, so values other than 32 will
        presumably not work -- TODO confirm.

    Returns (from ``forward``)
    -------
    ``(out, decoder_1, tenm_out, fivem_out)``: four single-channel maps that
    all end up at the finest resolution:
      * ``decoder_1`` -- direct 1x1-conv prediction from the finest features,
      * ``tenm_out``  -- the coarsest refined prediction, bilinearly upsampled 4x,
      * ``fivem_out`` -- the mid-scale refined prediction, bilinearly upsampled 2x,
      * ``out``       -- the finest refined prediction.

    NOTE(review): ``block2xc``, ``downsampleconvb``, ``norm_layer_bdown``,
    ``downsampleconv1x1b``, ``norm_layer_bdown2``, ``upsample1x1convb`` and
    ``norm_layer_bup`` are constructed but never called in ``forward``; they
    still contribute parameters to the state dict and the optimizer.
    """

    def __init__(self,
                 in_channels=10,
                 out_channels=1,
                 feature_channels=32,
                 upscale=2,
                 tenmfeats = 32):
        super(HRSupResNet, self).__init__()
        
        # Shared activation (Mish, despite the attribute name).
        self.lrelu = nn.Mish()
        
        # 1x -> 2x: feature extraction at input resolution, then CARAFE
        # upsampling (presumably 2x, 10 m -> 5 m -- CARAFE is defined outside
        # this file; TODO confirm its default scale).
        self.feats = nn.Conv2d(in_channels, 32, 1, 1, bias = True)
        self.upsampler = CARAFE(32)
        self.PSA_p1 = PSA_p(32)
        self.PSA_p2 = PSA_p(16)
        
        # 2x residual blocks
        self.block2xb = ResidualPSA(32, 32)
        self.block2xc = ResidualPSA(32, 32)  # NOTE(review): never used in forward
        
        # 1x residual blocks
        self.block1xa = ResidualPSA(tenmfeats, tenmfeats)
        self.block1xb = ResidualPSA(tenmfeats, tenmfeats)
        self.block1xc = ResidualPSA(tenmfeats, tenmfeats)
                
        # 2x -> 1x downsample skip connect (2xa - 1xb):
        # stride-2 conv, then a 1x1 conv over the concatenation with the 1x features
        self.downsampleconva = nn.Conv2d(32, 32, 3, 2, 1, bias = False)
        self.norm_layer_adown = nn.BatchNorm2d(32)
        self.downsampleconv1x1a = nn.Conv2d(32+tenmfeats, tenmfeats, 1, bias = False)
        self.norm_layer_adown2 = nn.BatchNorm2d(tenmfeats)
        
        # 2x -> 1x downsample skip connect (2xb - 1xc)
        # NOTE(review): none of these four modules is used in forward
        self.downsampleconvb = nn.Conv2d(32, 32, 3, 2, 1, bias = False)
        self.norm_layer_bdown = nn.BatchNorm2d(32)
        self.downsampleconv1x1b = nn.Conv2d(32+tenmfeats, tenmfeats, 1, bias = False)
        self.norm_layer_bdown2 = nn.BatchNorm2d(tenmfeats)
        
        # 1x -> 2x upsample skip connect
        #(1xa - 2xb)
        self.upsample1x1conva = nn.Conv2d(32+tenmfeats, 32, 1, bias = False)
        self.norm_layer_aup = nn.BatchNorm2d(32)
        
        #(1xb - 2xc)
        # NOTE(review): not used in forward
        self.upsample1x1convb = nn.Conv2d(32+tenmfeats, 32, 1, bias = False)
        self.norm_layer_bup = nn.BatchNorm2d(32)
        
        #(1xc - 2xd)
        self.upsample1x1convc = nn.Conv2d(32+tenmfeats, 32, 1, bias = False)
        self.norm_layer_cup = nn.BatchNorm2d(32)
        
        # 2 - > 4: reduce to 16 channels, CARAFE-upsample, refine, then fuse
        # with the bilinearly upsampled 2x features (32 + 16 -> 32 channels)
        self.downsamplefeatsb = nn.Conv2d(32, 16, 1, 1, bias = True)
        self.upsamplerb = CARAFE(16)
        self.block4xa = ResidualPSA(16, 16)
        self.upsample1x1convd = nn.Conv2d(32+16, 32, 1, bias = False)
        self.norm_layer_dup = nn.BatchNorm2d(32)
        
        # Decoder: one CFP module, two refinement convs and one axial
        # attention kernel per scale (suffix 1 = finest, 3 = coarsest).
        self.CFP_1 = CFPModule(32, d = 8)
        self.CFP_2 = CFPModule(32, d = 8)
        self.CFP_3 = CFPModule(32, d = 8)

        self.ra1_conv1 = Conv(32, 32,3,1,padding=1,bn_acti=True)
        self.ra1_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
        
        self.ra2_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True)
        self.ra2_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
        
        self.ra3_conv1 = Conv(32,32,3,1,padding=1,bn_acti=True)
        self.ra3_conv3 = Conv(32,1,3,1,padding=1,bn_acti=True)
        
        self.aa_kernel_1 = AA_kernel(32, 32)
        self.aa_kernel_2 = AA_kernel(32,32)
        self.aa_kernel_3 = AA_kernel(32,32)


        # Classifier: 1x1 conv producing the initial single-channel prediction
        self.conv_last = nn.Conv2d(32, 1, 1, 1, bias = True)

    def forward(self, x):
        # 1x1 conv feature extraction + activation, then a residual PSA block
        fea1xa = self.block1xa(self.lrelu(self.feats(x)))
        
        # Upsample 1x -> 2x
        # The output of this is attention features on the 5m feature extractions
        fea2xa = self.upsampler(fea1xa)
        fea2xa = self.PSA_p1(fea2xa)
        
        # B Blocks
        # residual blocks on the 5-m and 10-m feats
        fea2xb = self.block2xb(fea2xa)
        fea1xb = self.block1xb(fea1xa)
        
        # C blocks
        # (1xa -> 2xb)
        # Cross-scale feature fusion: upsample the 1x features, concatenate
        # with the 2x features and mix with a 1x1 conv + BN
        fea1xup = F.interpolate(fea1xb, scale_factor=2, mode='bilinear')
        fea1xup = torch.concat([fea1xup, fea2xb], axis = 1)
        fea1xupout_b = self.norm_layer_aup(self.upsample1x1conva(fea1xup))
        # (2xa -> 1xb): the mirror path, 2x features strided down to 1x
        fea2xbdown = self.lrelu(self.norm_layer_adown(self.downsampleconva(fea2xb)))
        fea2xbdown = torch.concat([fea2xbdown, fea1xb], axis = 1)
        fea2xdownout_b = self.norm_layer_adown2(self.downsampleconv1x1a(fea2xbdown))
        
        # NOTE(review): block2xb is applied a second time here (shared
        # weights) while block2xc is never used -- confirm whether this was
        # meant to be self.block2xc.
        fea2xc = self.block2xb(self.lrelu(fea1xupout_b))
        fea1xc = self.block1xc(self.lrelu(fea2xdownout_b))
    
                               
        # D blocks
        # (1xb -> 2xc)
        fea1xupc = F.interpolate(fea1xc, scale_factor=2, mode='bilinear')
        fea1xupc = torch.concat([fea1xupc, fea2xc], axis = 1)
        fea1xupout_c = self.norm_layer_cup(self.upsample1x1convc(fea1xupc))
        
        # Only the activation is applied here; the residual block is disabled.
        fea2xd = (self.lrelu(fea1xupout_c))#self.block2xd(self.lrelu(fea1xupout_c))
        
        # 2 - > 4
        fea4xa = self.lrelu(self.downsamplefeatsb(fea2xd))
        # Here we make 2.5 feats
        fea4xa = self.upsamplerb(fea4xa)
        fea4xa = self.PSA_p2(fea4xa)
        fea4xa = self.block4xa(fea4xa)
        # Fuse the CARAFE path (16 ch) with bilinearly upsampled 2x features (32 ch)
        fea4xup = F.interpolate(fea2xd, scale_factor=2, mode='bilinear')
        fea4xa = torch.concat([fea4xa, fea4xup], axis = 1)
        fea4xa = self.lrelu(self.norm_layer_dup(self.upsample1x1convd(fea4xa)))


        # Axial attn decoder.
        # decoder_1 is the initial prediction at the finest scale; it is then
        # refined coarse-to-fine with reverse attention (1 - sigmoid(pred)).
        decoder_1 = self.conv_last(fea4xa)
        
        # Coarsest scale: prediction shrunk by 4x to match the 1x features
        decoder_2 = F.interpolate(decoder_1, scale_factor=0.25, mode='bilinear')
        cfp_out_1 = self.CFP_3(fea1xc) # 32 - 32
        decoder_2_ra = -1*(torch.sigmoid(decoder_2)) + 1  # reverse-attention weights
        aa_atten_3 = self.aa_kernel_3(cfp_out_1)
        aa_atten_3_o = decoder_2_ra.expand(-1, 32, -1, -1).mul(aa_atten_3)
        
        
        ra_3 = self.ra3_conv1(aa_atten_3_o) # 32 - 32
        #ra_3 = self.ra3_conv2(ra_3) # 32 - 32
        ra_3 = self.ra3_conv3(ra_3) # 32 - 1
        x_3 = ra_3 + decoder_2 # 10m prediction here (residual on the coarse prediction)
        #10m deep super: upsampled to the finest resolution for deep supervision
        tenm_out = F.interpolate(x_3, scale_factor=4, mode='bilinear')
        
        # Mid scale: refine the 2x-upsampled coarse prediction with the 2x features
        decoder_3 = F.interpolate(x_3, scale_factor=2, mode='bilinear')
        cfp_out_2 = self.CFP_2(fea2xc) # 32 - 32
        decoder_3_ra = -1*(torch.sigmoid(decoder_3)) + 1
        aa_atten_2 = self.aa_kernel_2(cfp_out_2)
        aa_atten_2_o = decoder_3_ra.expand(-1, 32, -1, -1).mul(aa_atten_2)
        
        ra_2 = self.ra2_conv1(aa_atten_2_o) # 32 - 32
        #ra_2 = self.ra2_conv2(ra_2) # 32 - 32
        ra_2 = self.ra2_conv3(ra_2) # 32 - 1
        
        x_2 = ra_2 + decoder_3 # 5m prediction here
        fivem_out = F.interpolate(x_2, scale_factor=2, mode='bilinear')
        
        # Finest scale: same refinement using the 4x features
        decoder_4 = F.interpolate(x_2, scale_factor=2, mode='bilinear')
        cfp_out_3 = self.CFP_1(fea4xa) # 32 - 32
        decoder_4_ra = -1*(torch.sigmoid(decoder_4)) + 1
        aa_atten_1 = self.aa_kernel_1(cfp_out_3)
        aa_atten_1_o = decoder_4_ra.expand(-1, 32, -1, -1).mul(aa_atten_1)
        
        ra_1 = self.ra1_conv1(aa_atten_1_o) # 32 - 32
        ra_1 = self.ra1_conv3(ra_1) # 32 - 1
        
        x_1 = ra_1 + decoder_4

        out = x_1
    
        # Four outputs: callers must unpack all of them.
        return out, decoder_1, tenm_out, fivem_out

In [None]:
# Here is where we actually make the model.
# The data cells below read and normalise 12-band rasters (`range(12)` in the
# statistics loop, `bandsin = 12` in make_input_data), so the first conv has
# to accept 12 channels; the class default of 10 would fail on those batches.
model = HRSupResNet(in_channels=12)

In [None]:
def get_n_params(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    return sum(param.numel() for param in model.parameters())

# Report the total parameter count of the model built above.
print(get_n_params(model))

In [None]:
criterion = nn.MSELoss() # Mean squared error between prediction and label

# Set optimizers, learning rate scheduler.
# NOTE(review): the training loop in this notebook only steps `optimizer`;
# `optimizer2` is not referenced there and `scheduler.step()` is never called,
# so the learning rate stays at 0.001 unless that is added.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer2 = torch.optim.SGD(model.parameters(), lr = 0.1, momentum = 0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1, verbose = True)

# Normalization -- TODO (Gautam)
# The CNN expects its input to be centered around 0.
# Find the per-band min and max over the whole dataset and store them
# so they can be passed to make_input_data.

In [None]:
import rasterio
import os
import numpy as np
import random

from rasterio.errors import RasterioIOError

# NOTE: the working directory is changed for the rest of the notebook;
# relative paths used later (e.g. when saving models) resolve against it.
os.chdir(r"C:\Users\Gautam.Mathur\OneDrive - World Resources Institute\Gautam Intern Materials\All_Images_6\resampled_results")
images = os.listdir()
images = [i for i in images if i.endswith(".tif")]

# Per-image statistics: one list of 12 per-band values for every readable image.
mins = []
maxes = []

for idx in range(5290):
    imgname = str(idx) + "final.tif"
    try:
        # Context manager so the dataset handle is closed after reading.
        with rasterio.open(imgname) as src:
            img = src.read().astype(np.float32)
    except RasterioIOError:
        # Only unreadable / missing files are skipped; other errors surface.
        print(str(idx) + " doesn't exist")
        continue

    if img.shape[0] < 12:
        print(f"{imgname} has only {img.shape[0]} bands, skipping")
        continue

    # Append both lists together so they stay aligned image-by-image.
    maxes.append([np.max(img[band, ...]) for band in range(12)])
    mins.append([np.min(img[band, ...]) for band in range(12)])

if not maxes:
    raise ValueError("no readable images found; cannot compute per-band statistics")

# Dataset-wide extremes per band: the largest per-image maximum and the
# smallest per-image minimum (this previously took max() of the minima).
maxperband = [max(per_image[band] for per_image in maxes) for band in range(12)]
minperband = [min(per_image[band] for per_image in mins) for band in range(12)]


    

#midrange = (maxs[idx] + mins[idx]) / 2
#rng = maxs[idx] - mins[idx]
#arr[..., idx] = (arr[..., idx] - midrange[idx]) / (rng / 2)

In [None]:
def make_input_data(files, maxes = maxperband, mins = minperband,
                    batch_size = 8, size = 14, scale = 4, bandsin = 12, bandsout = 3,
                    imgfolder = r"C:\Users\Gautam.Mathur\All_Images_6"):
    """Load one batch of input / label rasters and normalise the inputs.

    Parameters
    ----------
    files: sequence of sample ids. Inputs are read from
        ``<imgfolder>/resampled_results/<id>final.tif`` and labels from
        ``<imgfolder>/wv<id>.tif``.
    maxes, mins: per-band maximum / minimum used for normalisation
        (indexable by band; default to the dataset-wide statistics).
    batch_size: maximum number of samples to load; if ``files`` is shorter
        (e.g. the last batch of an epoch) only ``len(files)`` are loaded.
    size: height/width of the input tiles in pixels.
    scale: label resolution relative to the input (labels are ``size*scale``).
    bandsin, bandsout: number of input / label bands.
    imgfolder: root folder of the dataset.

    Returns
    -------
    x_batch: float array ``(n, bandsin, size, size)``; each band is mapped
        from ``[min, max]`` to ``[-1, 1]``.
    y_batch: float array ``(n, size*scale, size*scale, bandsout)`` (channels last).
    """
    n_samples = min(batch_size, len(files))

    # Make the batch as empty arrays
    x_batch = np.zeros((n_samples, bandsin, size, size))
    y_batch = np.zeros((n_samples, size*scale, size*scale, bandsout))

    for i in range(n_samples):
        xpath = os.path.join(imgfolder, "resampled_results", str(files[i]) + "final" + ".tif")
        ypath = os.path.join(imgfolder, "wv" + str(files[i]) + ".tif")
        # Context managers close the dataset handles after reading.
        with rasterio.open(xpath) as src:
            x_batch[i] = src.read().astype(np.float32)
        with rasterio.open(ypath) as src:
            # rasterio returns (bands, rows, cols); y_batch is channels-last.
            y_batch[i] = np.moveaxis(src.read().astype(np.float32), 0, -1)

    # Normalise per band. Axis 1 is the band axis (axis 0 is the sample axis).
    for idx in range(0, bandsin):
        midrange = (maxes[idx] + mins[idx]) / 2
        half_range = (maxes[idx] - mins[idx]) / 2
        if half_range == 0:
            # Constant band: only centre it, avoiding a division by zero.
            x_batch[:, idx, ...] = x_batch[:, idx, ...] - midrange
        else:
            x_batch[:, idx, ...] = (x_batch[:, idx, ...] - midrange) / half_range
    return x_batch, y_batch



In [None]:
def save_model(model, optimizer, path, name):
    """Write a checkpoint to ``<path><name>/model``.

    ``path`` and ``name`` are concatenated directly (no separator is added),
    and the resulting folder is created when it does not exist. The saved
    dict holds ``'epoch'`` (set to ``name``), ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.
    """
    output_folder = f"{path}{name}"

    resolved_folder = os.path.realpath(output_folder)
    if not os.path.exists(resolved_folder):
        os.makedirs(resolved_folder)

    checkpoint = {
        'epoch': name,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, f"{output_folder}/model")

In [None]:
def make_preds(logits):
    """Turn logits into a two-channel map ``[1 - p, p]`` with ``p = sigmoid(logits)``.

    The two maps are concatenated along dimension 1.
    """
    positive = torch.sigmoid(logits)
    return torch.cat((1 - positive, positive), dim=1)
#def batches(idxlist, outputlist, batch_size=8):
        #if len(idxlist)>=batch_size:
            #random.shuffle(idxlist)
            #batch = idxlist[:batch_size]
            #outputlist.append(batch)
            #idxlist=idxlist[batch_size:]
            #batches(idxlist, outputlist)

#model = torch.load('models/supres/HRNet_5m_400/model')
model.train()
model.float()  # cast once instead of on every forward pass

# NOTE(review): `scheduler` is created above but never stepped here, so the
# learning rate is constant. It is left that way on purpose: with
# step_size=10 / gamma=0.1 the rate would collapse long before epoch 200.
for epoch in range(1, 200):  # loop over the dataset multiple times

    running_loss = 0.0
    n_steps = 0

    scale = 4
    size = 28

    # Shuffle the sample ids and cut them into batches. The final batch may
    # hold fewer than `batch_size` samples.
    train_samples = np.arange(0, 5290)
    random.shuffle(train_samples)
    batch_size = 8
    batchnames = [train_samples[start:start + batch_size]
                  for start in range(0, len(train_samples), batch_size)]

    for batch in batchnames:

        # Pass the real batch length so a short final batch does not index
        # past the end of `batch`.
        inputs, labels = make_input_data(batch,
                                         batch_size = len(batch),
                                         size = size,
                                         scale = scale)

        inputs = torch.tensor(inputs).float()
        labels = torch.tensor(labels).float()

        # forward + backward + optimize
        # The model returns four maps; the raw pre-refinement prediction is
        # not part of the loss.
        outputs, _initial_out, tenm_out, fivem_out = model(inputs)

        outputs = make_preds(outputs)
        fivem_out = make_preds(fivem_out)
        tenm_out = make_preds(tenm_out)

        # NOTE(review): make_input_data declares labels as
        # (batch, H, W, bandsout), which makes `labs` 5-D while the
        # predictions are (batch, 2, H, W). The label band to supervise on
        # still has to be chosen -- confirm before relying on this loss.
        labs = torch.cat([1 - labels[:, np.newaxis, ...],
                          labels[:, np.newaxis, ...]], axis = 1)

        loss = criterion(outputs, labs)
        loss_5m = criterion(fivem_out, labs)
        loss_tenm = criterion(tenm_out, labs)

        # Deep supervision: average the losses of the three scales.
        loss = (loss + loss_5m + loss_tenm) / 3

        if not torch.isnan(loss):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_steps += 1
        else:
            print("nan loss")

    if epoch % 2 == 0:
        save_model(model, optimizer,
                   path = 'models/supres/folder_name', # Make a folder name
                   name = str(epoch).zfill(3))

    # Mean loss over the optimised steps of this epoch, labelled with the
    # same epoch number that is used for the checkpoint name.
    mean_loss = running_loss / max(n_steps, 1)
    print(f'[{epoch}] loss: {mean_loss}')

print('Finished Training')
#torch.save(model, "models/supres/HRNet_5m_1000_carafe_PSA_backbone_gru/model")

In [None]:
# Manual checkpoint after training. `epoch` is the loop variable left over
# from the training cell, so that cell must have run first. save_model
# concatenates path and name, giving 'models/supres/HRNet_4x_carafe_master-NNN'.
save_model(model, optimizer,
                   path = 'models/supres/HRNet_4x_carafe_master-',
                   name = str(epoch).zfill(3))

In [None]:
# Restore a checkpoint written by save_model (a dict with 'epoch',
# 'model_state_dict' and 'optimizer_state_dict'). The path is relative to the
# current working directory.
checkpoint = torch.load('models/supres/HRNet_4x_carafe_master-559/model')

# Load the weights and the optimizer state into the existing objects in place.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# INPUT
# - resampled result from coregistration (10 m, all 10 bands)
# - resampled 2.5 m WV image

# - the model is trained iteratively on random batches of size 8, 16, or 32
# - the output should be scaled from 0 to 1