In [8]:
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import torchvision

In [3]:
# Pretrained SRGAN generator weights from https://github.com/mseitzer/srgan/
# map_location="cpu": the checkpoint may have been saved from a GPU; mapping to
# CPU lets it load on any machine (the Generator below is built on the CPU).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# for untrusted checkpoints; on torch >= 1.13 consider weights_only=True.
GEN_WEIGHTS_PATH = "./weights/srgan_generator.pth"
gen_weights = torch.load(GEN_WEIGHTS_PATH, map_location="cpu")




In [4]:

# Generator
    # Input: low-resolution image of shape C x H x W (e.g. 3 x 24 x 24); output is C x 4H x 4W (e.g. 3 x 96 x 96)
    # 9x9 Conv -> PReLU   (***)
    # N Res-Blocks w/ skip-connections
        # Block:
            # Conv -> BN -> PReLU -> Conv -> BN -> Elementwise Sum (skip connection)

    # Conv -> BN -> Elementwise Sum (skip connection from ***, i.e. the input of the block stack)

    # Upsampler
        # 2x (Conv -> PixelShuffle(2) -> PReLU), each stage doubling H and W
        # 9x9 Conv back to C channels

class Generator(nn.Module):
    """SRResNet-style super-resolution generator (as used in SRGAN).

    Pipeline:
        9x9 Conv -> PReLU
        -> ``n_blocks`` residual blocks (Conv-BN-PReLU-Conv-BN + identity skip)
        -> Conv-BN with a long skip back to the first conv's output
        -> ``n_upsample_blocks`` stages of (Conv -> PixelShuffle(2) -> PReLU)
        -> 9x9 Conv back to ``n_channels``.

    Each upsampling stage doubles height and width, so an ``(N, n_channels, H, W)``
    batch yields ``(N, n_channels, H * 2**k, W * 2**k)`` with ``k = n_upsample_blocks``.

    Layer widths (including the 16x channel expansion before each PixelShuffle)
    and the order in which submodules are registered determine ``state_dict()``
    key order and tensor shapes. The pretrained checkpoint is copied into this
    model by position, so keep both stable.

    Args:
        n_channels: channels of the input and output images (e.g. 3).
        n_blocks: number of residual blocks in the trunk.
        n_features: feature width of the trunk.
        n_upsample_blocks: number of x2 PixelShuffle upsampling stages.
    """

    def __init__(self, n_channels, n_blocks, n_features, n_upsample_blocks):
        super().__init__()

        def _residual_block(width):
            # Conv -> BN -> PReLU -> Conv -> BN; the skip-add happens in forward().
            return nn.Sequential(
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(width, track_running_stats=True),
                # One PReLU slope per channel of *this* block. It was previously tied
                # to the outer n_features, which only worked while both were 64.
                nn.PReLU(num_parameters=width),
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(width, track_running_stats=True),
            )

        def _upsample_block(n_features_in, n_features_boost):
            # PixelShuffle(2) trades 4x channels for 2x height and width, so the
            # PReLU after it sees n_features_boost // 4 channels.
            return nn.Sequential(
                nn.Conv2d(n_features_in, n_features_boost, kernel_size=3, stride=1, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU(num_parameters=n_features_boost // 4),
            )

        self.num_blocks = n_blocks
        self.num_features = n_features
        self.num_upsample_blocks = n_upsample_blocks

        # NOTE: registration order (in_conv, blocks, out_conv, upsample, out)
        # fixes the state_dict() key order used by the positional weight copy.
        self.in_conv = nn.Sequential(
            nn.Conv2d(n_channels, n_features, kernel_size=9, stride=1, padding=4, bias=False),
            nn.PReLU(num_parameters=n_features),
        )

        self.blocks = nn.Sequential(*[_residual_block(n_features) for _ in range(n_blocks)])

        self.out_conv = nn.Sequential(
            nn.Conv2d(n_features, n_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_features),
        )

        # Every stage expands to 16 * n_features channels; PixelShuffle(2) leaves
        # 4 * n_features. The first stage starts from the trunk width.
        boost = n_features * 16
        in_width = n_features
        upsample_stages = []
        for _ in range(n_upsample_blocks):
            upsample_stages.append(_upsample_block(in_width, boost))
            in_width = boost // 4
        self.upsample = nn.Sequential(*upsample_stages)

        # in_width is the channel count leaving the upsampler: 4 * n_features after
        # at least one stage (previously hard-coded as 256 for n_features=64).
        self.out = nn.Conv2d(in_width, n_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, X):
        """Map an (N, n_channels, H, W) batch to (N, n_channels, H * 2**k, W * 2**k),
        where k = n_upsample_blocks."""
        X_first = self.in_conv(X)
        X_past = X_first
        for block in self.blocks:
            # Short, per-block skip connection.
            X_past = block(X_past) + X_past
        # Long skip connection around the whole residual stack.
        X_out = self.out_conv(X_past) + X_first
        return self.out(self.upsample(X_out))





    




In [5]:
# Shape probe for the upsampler: how do Conv2d and PixelShuffle(2) change
# (N, C, H, W)? PixelShuffle(r) maps C * r**2 channels to C channels and
# multiplies H and W by r.
c1 = nn.Conv2d(64, 1024, 3, 1, 1)
shuff = nn.PixelShuffle(2)
c2 = nn.Conv2d(256, 64, 3, 1, 1)

X = torch.randn(1, 64, 24, 24)
probe_steps = [
    (c1, (1, 1024, 24, 24)),
    (shuff, (1, 256, 48, 48)),
    (c2, (1, 64, 48, 48)),
    (shuff, (1, 16, 96, 96)),  # 64 channels / 4 -> only 16 survive the second shuffle
]
for layer, expected_shape in probe_steps:
    X = layer(X)
    print(X.shape)
    assert tuple(X.shape) == expected_shape, (layer, tuple(X.shape), expected_shape)

torch.Size([1, 1024, 24, 24])
torch.Size([1, 256, 48, 48])
torch.Size([1, 64, 48, 48])
torch.Size([1, 16, 96, 96])


In [6]:
# 3 image channels, 16 residual blocks, 64 trunk features, 2 x2 upsampling stages (4x total).
gen = Generator(n_channels=3, n_blocks=16, n_features=64, n_upsample_blocks=2)


In [7]:
# Copy the pretrained tensors into our Generator by *position*, not by key name.
# This relies on both state dicts listing tensors in the same order.
# num_batches_tracked buffers are skipped on our side. This assumes the
# checkpoint does not contain them -- TODO confirm.
pretrained_weights = list(gen_weights.items())

my_model = gen.state_dict()
target_keys = [k for k in my_model if "num_batches_tracked" not in k]

if len(target_keys) != len(pretrained_weights):
    raise ValueError(
        f"Generator expects {len(target_keys)} tensors but the checkpoint has "
        f"{len(pretrained_weights)}; a positional copy would be misaligned"
    )

for k, (layer_name, weights) in zip(target_keys, pretrained_weights):
    # A shape mismatch means the positional alignment is off; fail with context.
    if weights.shape != my_model[k].shape:
        raise ValueError(
            f"shape mismatch: model {k} {tuple(my_model[k].shape)} vs "
            f"checkpoint {layer_name} {tuple(weights.shape)}"
        )
    my_model[k] = weights
idx = len(target_keys)  # number of tensors copied (kept from the original loop counter)

# state_dict() returns a fresh dict, so reassigning its entries does not change
# the model. Load the updated dict back to actually apply the weights.
gen.load_state_dict(my_model)


