In [2]:
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [18]:

class ResBlock(nn.Module):
    """Residual unit: two (3x3 conv -> BatchNorm -> PReLU) stages plus an identity skip.

    Every convolution keeps ``n_features`` channels with stride 1 / padding 1, so the
    output has the same shape as the input and can be added to it directly.

    Args:
        n_features: channel count of both the input and the output.
    """

    def __init__(self, n_features):
        super().__init__()
        stages = []
        # Two identical conv/BN/PReLU stages, appended in order (Sequential indices 0-5).
        # NOTE(review): SRGAN's residual block has no activation after the second
        # BatchNorm; the trailing PReLU here is kept as-is — confirm it is intended.
        for _ in range(2):
            stages += [
                nn.Conv2d(n_features, n_features, 3, 1, 1, bias=False),
                nn.BatchNorm2d(n_features),
                nn.PReLU(n_features),
            ]
        self.layers = nn.Sequential(*stages)

    def forward(self, X):
        residual = self.layers(X)
        return residual + X

class Body(nn.Module):
    """Chain of ``n_blocks`` ResBlock modules applied one after another.

    No extra skip connection is added here; callers wrap this stack themselves.

    Args:
        n_features: channel width passed to every ResBlock.
        n_blocks: how many ResBlocks to stack (0 gives an identity mapping).
    """

    def __init__(self, n_features, n_blocks):
        super().__init__()
        self.n_features, self.n_blocks = n_features, n_blocks
        self.blocks = nn.Sequential(*(ResBlock(n_features) for _ in range(n_blocks)))

    def forward(self, X):
        return self.blocks(X)



class Generator(nn.Module):
    """SRGAN-style super-resolution generator.

    Pipeline: 9x9 conv stem -> residual ``Body`` -> 3x3 conv + BatchNorm with a
    global skip back to the stem output -> two 2x pixel-shuffle upsampling stages
    -> 9x9 conv back to ``n_channels``. Spatial size grows 4x overall
    (e.g. 24x24 -> 96x96).

    Args:
        device: stored as ``self.device``; not used inside this class.
        n_channels: image channels of both input and output.
        n_features: channel width of the stem and the residual body.
        n_blocks: number of residual blocks in the body.
    """

    def __init__(self, device, n_channels, n_features, n_blocks):
        super(Generator, self).__init__()
        self.device = device
        self.n_channels = n_channels
        self.n_features = n_features
        self.n_blocks = n_blocks

        self.in_conv = nn.Sequential(
            nn.Conv2d(n_channels, n_features, 9, 1, 4, bias=False),
            nn.PReLU(n_features),
        )

        self.body = Body(n_features=n_features, n_blocks=n_blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(n_features, n_features, 3, 1, 1, bias=False),
            nn.BatchNorm2d(n_features),
        )

        # Each stage expands to n_features*16 channels; PixelShuffle(2) folds them
        # into n_features*4 channels at twice the resolution.
        self.up_sample = nn.Sequential(
            self._upsample(n_features, n_features * 16),
            self._upsample(n_features * 4, n_features * 16),
        )
        self.out = nn.Sequential(
            nn.Conv2d(n_features * 4, n_channels, 9, 1, 4, bias=False),
        )

    @staticmethod
    def _upsample(n_features, n_features_up, up_factor=2):
        """Build a Conv -> PixelShuffle(up_factor) -> PReLU sub-pixel upsampling stage.

        PixelShuffle rearranges ``n_features_up`` channels into
        ``n_features_up // up_factor**2`` channels at ``up_factor`` times the
        resolution, so the PReLU is sized for that reduced channel count.

        Raises:
            ValueError: if ``n_features_up`` is not divisible by ``up_factor**2``.
        """
        shuffle_area = up_factor ** 2
        if n_features_up % shuffle_area:
            raise ValueError(
                f"n_features_up ({n_features_up}) must be divisible by "
                f"up_factor**2 ({shuffle_area})"
            )
        return nn.Sequential(
            nn.Conv2d(n_features, n_features_up, 3, 1, 1, bias=False),
            nn.PixelShuffle(up_factor),
            nn.PReLU(n_features_up // shuffle_area),
        )

    def forward(self, X):
        X = self.in_conv(X)
        # Global skip: body output (after conv+BN) is added back to the stem features.
        X = self.final_conv(self.body(X)) + X
        X = self.up_sample(X)
        return self.out(X)


class DiscBlock(nn.Module):
    """3x3 conv (no bias) -> BatchNorm -> LeakyReLU(0.2) discriminator stage.

    Args:
        n_features_in: input channel count.
        n_features_out: output channel count.
        stride: stride of the 3x3 convolution.
        padding: padding of the 3x3 convolution.
    """

    def __init__(self, n_features_in, n_features_out, stride, padding):
        super(DiscBlock, self).__init__()

        self.net = nn.Sequential(
            nn.Conv2d(n_features_in, n_features_out, 3, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(n_features_out),
            nn.LeakyReLU(0.2),
        )

    def forward(self, X):
        return self.net(X)


class Discriminator(nn.Module):
    """Conv discriminator that outputs one probability (Sigmoid) per input image.

    Args:
        n_channels: channels of the input images.
        n_features: base width; blocks widen to 2x, 4x and 8x this value.

    Returns (forward):
        Tensor of shape (B,) with values in (0, 1).
    """

    def __init__(self, n_channels, n_features):
        super(Discriminator, self).__init__()

        self.in_conv = nn.Sequential(
            nn.Conv2d(n_channels, n_features, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2),
        )

        # Strides are passed by keyword: every block uses stride=3, padding=1, so
        # each one shrinks the feature map roughly 3x.
        # NOTE(review): SRGAN's discriminator alternates stride 1/2 — confirm that
        # stride 3 everywhere is intended.
        self.layers = nn.Sequential(
            DiscBlock(n_features, n_features * 2, stride=3, padding=1),
            DiscBlock(n_features * 2, n_features * 2, stride=3, padding=1),
            DiscBlock(n_features * 2, n_features * 4, stride=3, padding=1),
            DiscBlock(n_features * 4, n_features * 4, stride=3, padding=1),
            DiscBlock(n_features * 4, n_features * 8, stride=3, padding=1),
            DiscBlock(n_features * 8, n_features * 8, stride=3, padding=1),
            # Six stride-3 convs only reach 1x1 for inputs up to 729 px per side;
            # global average pooling guarantees the Linear below always receives
            # exactly n_features*8 values (it is a no-op when the map is already 1x1).
            # Pool and flatten share one parameter-free slot so the Linear layers
            # keep their state_dict keys (layers.7 / layers.9).
            nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1, -1)),
            nn.Linear(n_features * 8, n_features * 16),
            nn.LeakyReLU(0.2),
            nn.Linear(n_features * 16, 1),
            nn.Sigmoid(),
            nn.Flatten(0, -1),  # (B, 1) -> (B,)
        )

    def forward(self, X):
        X = self.in_conv(X)
        return self.layers(X)

    

In [19]:


import dotenv

# Load project-level settings, then make sure the MPS CPU-fallback flag is set
# without writing (and truncating) any .env file on disk.
dotenv.load_dotenv("../.env", override=True)
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# NOTE(review): PyTorch appears to read PYTORCH_ENABLE_MPS_FALLBACK when torch is
# imported, so setting it after the imports cell may have no effect; if the
# fallback is needed, set it before `import torch`.

# Prefer Apple MPS, then CUDA, otherwise CPU, so the cell runs on any machine.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

gen = Generator(device, 3, 64, 16).to(device)
disc = Discriminator(3, 64).to(device)

# Smoke test: a 24x24 batch should come out 4x larger (96x96), and the
# discriminator should emit one score per image.
X = torch.randn(32, 3, 24, 24).to(device)
with torch.no_grad():
    preds = gen(X)
    out = disc(preds)
print(preds.shape)
print(out.shape)
assert preds.shape == (32, 3, 96, 96)
assert out.shape == (32,)



  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([32, 3, 96, 96])
torch.Size([32])



