In [2]:
import torch
import torchvision
from torch import nn

In [5]:
# Randomly initialised VGG-19 (no weight download) — used only to inspect the architecture.
model = torchvision.models.vgg19(weights=None)
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [12]:
# Shape probe: how much does the VGG-19 conv trunk downsample a full-HD-sized input?
# no_grad: this is only a shape check. `model`'s parameters require grad, so
# without it the forward pass would record the autograd graph and keep every
# intermediate activation of this large input alive.
# NOTE(review): the tensor is (N, C, H, W) with H=1920, W=1080 (portrait);
# landscape 1080p would be (1, 3, 1080, 1920) — confirm which was intended.
with torch.no_grad():
    probe_features = model.features(torch.rand(1, 3, 1920, 1080))
probe_features.shape


torch.Size([1, 512, 60, 33])

In [None]:
class FastSRGAN(nn.Module):
    def __init__(self, args):
        """Set up resolutions, the frozen VGG feature extractor, G/D networks and their optimizers.

        Args:
            args: configuration object. This method reads ``args.hr_size``
                (side length of the square high-resolution image) and
                ``args.lr`` (generator learning rate; the discriminator uses 5x).
        """
        # nn.Module bookkeeping must exist before any submodule (self.vgg,
        # self.generator, ...) is assigned, otherwise nn.Module.__setattr__ raises.
        super().__init__()

        self.hr_height = args.hr_size
        self.hr_width = args.hr_size
        self.lr_height = self.hr_height // 4  # Low resolution height
        self.lr_width = self.hr_width // 4  # Low resolution width
        # NOTE(review): shapes are channels-last (H, W, C); PyTorch tensors are
        # usually (C, H, W) — confirm how these are consumed.
        self.lr_shape = (self.lr_height, self.lr_width, 3)
        self.hr_shape = (self.hr_height, self.hr_width, 3)
        self.iterations = 0

        # Number of inverted residual blocks in the mobilenet generator
        self.n_residual_blocks = 6

        # Frozen VGG-19 feature extractor used by content_loss.
        self.vgg = self.build_vgg()

        # Discriminator output patch: HR side length divided by 2**4.
        patch = int(self.hr_height / 2 ** 4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32  # Realtime Image Enhancement GAN Galteri et al.
        self.df = 32

        # Build the discriminator and the generator; both must exist before
        # their optimizers can be created.
        self.discriminator = self.build_discriminator()
        self.generator = self.build_generator()

        # Optimizers need the parameters they update.
        # NOTE(review): assumes build_generator/build_discriminator return
        # nn.Module instances — TODO confirm.
        self.gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=args.lr)
        # TTUR - Two Time Scale Updates: the discriminator learns 5x faster.
        self.disc_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=args.lr * 5)

        # Torch equivalent of Keras ExponentialDecay(decay_steps=100000,
        # decay_rate=0.1, staircase=True): multiply the lr by 0.1 every 100000
        # steps. Call .step() once per training iteration to reproduce it.
        self.gen_scheduler = torch.optim.lr_scheduler.StepLR(
            self.gen_optimizer, step_size=100000, gamma=0.1)
        self.disc_scheduler = torch.optim.lr_scheduler.StepLR(
            self.disc_optimizer, step_size=100000, gamma=0.1)

        # Custom methods / loss helpers.
        # `.transforms` on a weights enum is a factory and must be called to get
        # the actual preprocessing transform.
        # NOTE(review): torchvision documents this preset as resize + center-crop
        # + normalize, which is unsuitable for full-resolution perceptual loss.
        self.transforms = torchvision.models.VGG19_Weights.IMAGENET1K_V1.transforms()
        self.MSEVGG_Loss = nn.MSELoss()

    
    def content_loss(self, hr, sr):
        transforms = 
        
        #sr = keras.applications.vgg19.preprocess_input(((sr + 1.0) * 255) / 2.0)
        #hr = keras.applications.vgg19.preprocess_input(((hr + 1.0) * 255) / 2.0)
        sr = self.transforms(sr)
        hr = self.transforms(hr)
        #sr_features = self.vgg(sr) / 12.75
        #hr_features = self.vgg(hr) / 12.75
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(hr) / 12.75
        
        return self.MSEVGG_Loss(hr_features, sr_features)



    def build_vgg(self):
        """Return the frozen convolutional trunk of a pretrained VGG-19.

        Loads torchvision's ``'DEFAULT'`` VGG-19 weights, keeps only the
        ``features`` (conv/ReLU/max-pool) stack, and turns off ``requires_grad``
        on every parameter so it acts as a fixed feature extractor.

        NOTE(review): this is the whole ``features`` stack, including the final
        max-pool. If the goal is to mirror a Keras ``block5_conv4`` output,
        ``features[:36]`` would be the torchvision equivalent — confirm.
        """
        feature_extractor = torchvision.models.vgg19(weights='DEFAULT').features
        # Module.requires_grad_ flips the flag on every parameter it owns.
        feature_extractor.requires_grad_(False)
        return feature_extractor

    def build_generator(self):
        
        def _make_divisible(v, divisor, min_value=None):
                if min_value is None:
                    min_value = divisor
                new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
                # Make sure that round down does not go down by more than 10%.
                if new_v < 0.9 * v:
                    new_v += divisor
                return new_v
            
        def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
            channel_axis = 1 # By default in PyTorch the channel axis is dim=1

            in_channels = inputs.shape[1]
            pointwise_conv_filters = int(filters * alpha)
            pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
            x = inputs
            prefix = 'block_{}_'.format(block_id)

            if block_id:

                x = nn.Conv2d(in_channels = in_channels,
                              out_channels = expansion * in_channels,
                              kernel_size=3,
                              stride=stride,
                              bias=True)

