In [1]:
import time
import torch
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

class SuperResBlock(nn.Module):
    """Upsample Volume using subpixel convolution.
    
    Reference: https://arxiv.org/pdf/1609.05158.pdf"""
    def __init__(self, upscale_factor):
        super(SuperResBlock, self).__init__()
        self.activation = nn.ReLU()
        self.dconv1 = nn.Parameter(T.FloatTensor(64,8,5,5))
        self.dbias1 = nn.Parameter(T.FloatTensor(64))
        self.dpad1 = (2,2)
        self.dbn1 = nn.BatchNorm2d(64)
        self.dconv2 = nn.Parameter(T.FloatTensor(64,64,3,3))
        self.dbias2 = nn.Parameter(T.FloatTensor(64))
        self.dpad2 = (1,1)
        self.dbn2 = nn.BatchNorm2d(64)
        self.dconv3 = nn.Parameter(T.FloatTensor(32,64,3,3))
        self.dbias3 = nn.Parameter(T.FloatTensor(32))
        self.dpad3 = (1,1)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dconv4 = nn.Parameter(T.FloatTensor(upscale_factor**2,32,3,3))
        self.dbias4 = nn.Parameter(T.FloatTensor(upscale_factor**2))
        self.dpad4 = (1,1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        
        self.initialize_weights()

    def forward(self, x):        
        x = self.activation(self.dbn1(F.conv2d(x, self.dconv1, self.dbias1, padding=self.dpad1)))
        x = self.activation(self.dbn2(F.conv2d(x, self.dconv2, self.dbias2, padding=self.dpad2)))
        x = self.activation(self.dbn3(F.conv2d(x, self.dconv3, self.dbias3, padding=self.dpad3)))
        x = F.conv2d(x, self.dconv4, self.dbias4, padding=self.dpad4)
        x = self.pixel_shuffle(x)
        # add back single channel
#         x = x[:,:,None]
        return x
    
    def initialize_weights(self):
        nn.init.orthogonal_(self.dconv1, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv2, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv3, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv4)
        for bn in [self.dbn1,self.dbn2,self.dbn3]:
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
            
class SuperResBlockNotFunctional(nn.Module):
    """Upsample Volume using subpixel convolution.
    
    Reference: https://arxiv.org/pdf/1609.05158.pdf"""
    def __init__(self, upscale_factor):
        super(SuperResBlockNotFunctional, self).__init__()
        self.activation = nn.ReLU()
        self.dpad1 = (2,2)
        self.dconv1 = nn.Conv2d(8,64,(5,5),padding=self.dpad1)
        self.dbn1 = nn.BatchNorm2d(64)
        self.dpad2 = (1,1)
        self.dconv2 = nn.Conv2d(64,64,(3,3),padding=self.dpad2)
        self.dbn2 = nn.BatchNorm2d(64)
        self.dpad3 = (1,1)
        self.dconv3 = nn.Conv2d(64,32,(3,3),padding=self.dpad3)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dpad4 = (1,1)
        self.dconv4 = nn.Conv2d(32,upscale_factor**2,(3,3),padding=self.dpad4)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        
        self.initialize_weights()

    def forward(self, x):        
        x = self.activation(self.dbn1(self.dconv1(x)))
        x = self.activation(self.dbn2(self.dconv2(x)))
        x = self.activation(self.dbn3(self.dconv3(x)))
        x = self.dconv4(x)
        x = self.pixel_shuffle(x)
        return x
    
    def initialize_weights(self):
        nn.init.orthogonal_(self.dconv1.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv2.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv3.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.dconv4.weight)
        for bn in [self.dbn1,self.dbn2,self.dbn3]:
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)


class tofp16(nn.Module):
    """
    Model wrapper that implements::
        def forward(self, input):
            return input.half()
    """

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        return input.half()


def BN_convert_float(module):
    '''
    Designed to work with network_to_half.
    BatchNorm layers need parameters in single precision.
    Find all layers and convert them back to float. This can't
    be done with built in .apply as that function will apply
    fn to all modules, parameters, and buffers. Thus we wouldn't
    be able to guard the float conversion based on the module type.
    '''
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        BN_convert_float(child)
    return module


def network_to_half(network):
    """
    Convert model to half precision in a batchnorm-safe way.
    """
    return nn.Sequential(tofp16(), BN_convert_float(network.half()))
            
input_c = 8
W = 256
H = 256

In [5]:
net = SuperResBlock(2).cuda()
inp = T.randn(64, input_c, 256, 256, requires_grad=True).cuda()

for i in range(5):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.sum()
    loss.backward()

T.cuda.synchronize()
start=time.time()
for i in range(100):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.sum()
    loss.backward()
T.cuda.synchronize()
end=time.time()

print("Functional convolution FP32 Iterations per second: ", 100/(end-start))


net = SuperResBlockNotFunctional(2).cuda()
# inp = T.randn(64, input_c, 256, 256, requires_grad=True).cuda()

for i in range(5):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.sum()
    loss.backward()

T.cuda.synchronize()
start=time.time()
for i in range(100):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.sum()
    loss.backward()
T.cuda.synchronize()
end=time.time()

print("FP32 Iterations per second: ", 100/(end-start))

Functional convolution FP32 Iterations per second:  2.993901670164117
FP32 Iterations per second:  2.998001570069859


In [11]:
inp.shape

torch.Size([4])

In [4]:
net = network_to_half(SuperResBlock(2).cuda())
inp = T.cuda.HalfTensor(64, input_c, 256, 256).normal_()
inp.requires_grad = True

T.cuda.synchronize()
start=time.time()
for i in range(100):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.float().sum()
    loss.backward()
T.cuda.synchronize()
end=time.time()

print("Functional convolution FP16 Iterations per second: ", 100/(end-start))

Functional convolution FP16 Iterations per second:  5.623211761576424


In [12]:
net = network_to_half(SuperResBlockNotFunctional(2).cuda())
inp = T.randn(64, input_c, 256, 256, requires_grad=True).half().cuda()

T.cuda.synchronize()
start=time.time()
for i in range(100):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.float().sum()
    loss.backward()
T.cuda.synchronize()
end=time.time()

print("FP16 Iterations per second: ", 100/(end-start))

Functional convolution FP16 Iterations per second:  5.647740691780046
FP16 Iterations per second:  2.3371556989932443


In [6]:
net = network_to_half(SuperResBlock(2).cuda())
inp = T.cuda.HalfTensor(64, input_c, 255, 255).normal_()
inp.requires_grad = True

T.cuda.synchronize()
start=time.time()
for i in range(100):
    net.zero_grad()
    out = net.forward(inp)
    loss = out.float().sum()
    loss.backward()
T.cuda.synchronize()
end=time.time()

print("Functional convolution FP16 Iterations per second: ", 100/(end-start))


Functional convolution FP16 Iterations per second:  5.236900899320723
