In [1]:
import torch
from torch import nn

In [18]:
class CNNblock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(CNNblock, self).__init__()
        self.stride = stride
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride = self.stride, padding_mode='reflect'),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    def forward(self, x):
        return self.conv(x)
    

In [41]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features = [64,128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size = 4, stride =2, padding =1, padding_mode = "reflect",),
            nn.LeakyReLU(0.2),
        )
        
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNblock(in_channels, feature, stride =1 if feature == features[-1] else 2 )
            )
            in_channels = feature
        
        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size = 4, stride =1, padding =1, padding_mode = "reflect")
        )
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, x,y):
        x = torch.cat([x,y], dim =1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [43]:
print(device)

cuda


In [48]:
x  = torch.randn((1,3, 256, 256))
y = torch.randn((1,3, 256, 256))
model = Discriminator(in_channels = 3)

In [49]:
preds = model(x,y)

In [50]:
print(model)

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): CNNblock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): CNNblock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): CNNblock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       

In [51]:
print(preds.shape)

torch.Size([1, 1, 26, 26])


In [52]:
from torchsummary import summary


In [55]:
summary(model, [(3, 256, 256), (3, 256, 256)])

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same