In [1]:
import torch
import torch.nn as nn

class CNNBlock(nn.Module):
    """Conv(4x4) -> BatchNorm -> LeakyReLU(0.2) unit used by the patch discriminator.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the 4x4 convolution.
        padding: Padding applied on every side before the convolution. The
            default of ``0`` keeps the block's historical output size; note
            that ``padding_mode='reflect'`` only has an effect when this is
            greater than zero.
    """

    def __init__(self, in_channels, out_channels, stride, padding=0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=padding,
                bias=False,  # no conv bias; a BatchNorm layer follows directly
                padding_mode='reflect',
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply convolution, batch normalisation and LeakyReLU to ``x``."""
        return self.conv(x)

## Patch discriminator

In [2]:

    
class Discriminator(nn.Module):
    """Patch discriminator that scores an (input, target) pair patch by patch.

    ``forward`` concatenates its two arguments along the channel axis, so the
    first convolution is built for ``in_channels * 2`` channels: both
    arguments are expected to carry ``in_channels`` channels each.

    Args:
        in_channels: Channel count of each of the two tensors given to
            ``forward``.
        features: Channel widths of the stacked ``CNNBlock`` layers. Any
            non-empty sequence is accepted; the default is an immutable tuple
            so the default value cannot be shared and mutated between
            instances.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Stem: no normalisation on the first layer.
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

        # NOTE(review): the loop visits features[0] as well, so the stack starts
        # with a features[0] -> features[0] block. PatchGAN implementations
        # usually iterate features[1:]; kept as is so the output shape and the
        # checkpoint keys stay the same - confirm whether this is intended.
        blocks = []
        prev_channels = features[0]
        for width in features:
            # Every block downsamples except those whose width equals the last entry.
            blocks.append(
                CNNBlock(prev_channels, width, stride=1 if width == features[-1] else 2),
            )
            prev_channels = width

        # Final projection to a single-channel map of per-patch logits.
        blocks.append(
            nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )

        self.layer = nn.Sequential(*blocks)

    def forward(self, x, y):
        """Return the patch logits for the pair ``(x, y)``.

        ``x`` and ``y`` must share batch and spatial dimensions; together they
        must supply ``in_channels * 2`` channels.
        """
        z = torch.cat([x, y], dim=1)
        return self.layer(self.initial(z))


#### Test case for the patch discriminator

In [3]:
# Two random 3-channel 256x256 batches stand in for an (input, target) pair;
# joining them on the channel axis shows what the discriminator would see.
x, y = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
z = torch.cat((x, y), 1)
z.shape

torch.Size([1, 6, 256, 256])

In [4]:
# Fix the RNG so the random inputs and the freshly initialised weights repeat.
torch.manual_seed(42)


def test_discriminator():
    """Smoke-test ``Discriminator`` on a random pair of 2-channel 256x256 inputs.

    Prints the shape of the patch map followed by the raw predictions.
    """
    input_shape = (1, 2, 256, 256)
    # Inputs are drawn before the model is built, keeping the RNG order fixed.
    x, y = torch.randn(input_shape), torch.randn(input_shape)

    preds = Discriminator()(x, y)
    print("Output shape:", preds.shape)
    print("Predictions:", preds)


test_discriminator()


Output shape: torch.Size([1, 1, 10, 10])
Predictions: tensor([[[[ 0.3128, -0.0162, -0.2867,  0.1372, -0.0033,  0.1759,  0.5144,
            0.2245,  0.5993,  0.9960],
          [ 0.5208,  0.5942,  0.2155, -0.1409,  0.5804,  0.0968,  0.5500,
            0.3194, -0.1017,  0.1733],
          [ 0.3673,  0.3999,  0.4392,  0.8167, -0.0501,  0.4871,  0.1982,
            0.6552, -0.5762,  0.5614],
          [ 0.3511,  0.5382, -0.0458,  0.6049, -0.1525,  0.1308,  0.2330,
            0.4913,  1.1680,  0.3298],
          [-0.3581,  0.3010,  0.5054,  0.4582,  0.8617,  0.3152, -0.3276,
            0.1838,  0.3992,  0.0107],
          [-0.2541,  0.5760, -0.1197,  0.1642, -0.1338,  0.9193,  0.3654,
            0.5435,  0.2882,  0.2399],
          [ 0.2375,  0.2074,  0.0512,  0.1200,  0.3323,  0.5569,  0.1875,
            0.1244,  0.2603,  0.6403],
          [ 0.6800,  0.0412, -0.6289,  0.7676,  0.2540, -0.0647,  0.7626,
           -0.3743,  0.9909, -0.1344],
          [ 0.2194,  0.6420,  0.7806,  0.3

## U-Net Generator

In [5]:
class Block(nn.Module):
    """One resampling stage of the U-Net: (de)convolution, BatchNorm, activation.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        down: When true a stride-2 ``Conv2d`` (reflect padding) is used,
            otherwise a stride-2 ``ConvTranspose2d``.
        act: ``"relu"`` selects ``ReLU``; any other value selects
            ``LeakyReLU(0.2)``.
        use_dropout: When truthy, ``Dropout(0.5)`` is applied to the output.
    """

    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()

        if down:
            resample = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False, padding_mode='reflect',
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False,
            )

        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)
        self.use_dropout = use_dropout
        # Always created; only used in forward() when use_dropout is truthy.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Run the stage, applying dropout afterwards if it is enabled."""
        features = self.conv(x)
        if self.use_dropout:
            return self.dropout(features)
        return features

In [6]:
class UnetGenerator(nn.Module):
    """U-Net generator: eight stride-2 encoder stages mirrored by eight decoder stages.

    Every decoder stage after the first receives the previous decoder output
    concatenated with the encoder feature map of matching resolution (skip
    connection), which is why those stages take doubled input widths.

    Args:
        in_channels: Channels of the input image.
        features: Base channel width; deeper stages use multiples of it.
        out_channels: Channels of the generated image. Defaults to ``3``,
            the value that used to be hard-coded in the output layer.

    Note:
        The input is halved eight times on the way down, so its height and
        width must survive eight halvings (the notebook uses 256x256).
    """

    def __init__(self, in_channels, features=64, out_channels=3):
        super().__init__()
        # Encoder stem: no normalisation on the first layer.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

        # The dropout flags must be real booleans. They used to be passed as
        # strings, and since "false" is a non-empty (truthy) string, dropout
        # ended up active in every block of the network.
        self.down1 = Block(features, features * 2, down=True, act="leakyrelu", use_dropout=False)
        self.down2 = Block(features * 2, features * 4, down=True, act="leakyrelu", use_dropout=False)
        self.down3 = Block(features * 4, features * 8, down=True, act="leakyrelu", use_dropout=False)
        self.down4 = Block(features * 8, features * 8, down=True, act="leakyrelu", use_dropout=False)
        self.down5 = Block(features * 8, features * 8, down=True, act="leakyrelu", use_dropout=False)
        self.down6 = Block(features * 8, features * 8, down=True, act="leakyrelu", use_dropout=False)

        # Bottleneck: plain convolution + ReLU, no normalisation.
        self.bottom_layer = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )

        # Decoder: dropout only in the first three stages.
        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)

        # Output head: upsample to the input resolution and squash to [-1, 1].
        self.top_layer = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate an ``out_channels`` image with the same height/width as ``x``."""
        # Encoder: keep every intermediate map for the skip connections.
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        last_down = self.bottom_layer(d7)

        # Decoder: each stage consumes the previous output joined with the
        # encoder map of the same resolution.
        up1 = self.up1(last_down)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))

        return self.top_layer(torch.cat([up7, d1], 1))


    

In [7]:
def test_generator():
    """Smoke-test ``UnetGenerator`` on one random 1-channel 256x256 image.

    Prints the output shape and then the output tensor itself.
    """
    # The sample is drawn before the model is built, keeping the RNG order fixed.
    sample = torch.randn(1, 1, 256, 256)
    generator = UnetGenerator(in_channels=1, features=64)
    preds = generator(sample)
    for item in (preds.shape, preds):
        print(item)


In [8]:
# Run the generator smoke test; it prints the output shape followed by the raw tensor.
test_generator()

torch.Size([1, 3, 256, 256])
tensor([[[[-0.3342,  0.4635, -0.4298,  ...,  0.1046,  0.2946, -0.2066],
          [ 0.1404, -0.6600, -0.0810,  ...,  0.9957,  0.8506,  0.3114],
          [ 0.0954,  0.8522, -0.5192,  ...,  0.4843,  0.0160, -0.5520],
          ...,
          [-0.3403,  0.6162,  0.0026,  ...,  0.4549,  0.8869, -0.2023],
          [-0.0148,  0.0651,  0.8700,  ...,  0.8561, -0.8034, -0.4202],
          [ 0.2011, -0.3886,  0.6157,  ...,  0.3942,  0.4235, -0.1619]],

         [[ 0.1848,  0.1869, -0.1388,  ...,  0.7511, -0.5337,  0.6015],
          [-0.2722,  0.7725,  0.5086,  ...,  0.2859, -0.1938,  0.1051],
          [-0.1155,  0.1739, -0.8968,  ...,  0.4809, -0.9269,  0.6018],
          ...,
          [-0.6942,  0.1296, -0.8125,  ..., -0.1160,  0.6992, -0.5192],
          [-0.7908, -0.4556,  0.3742,  ..., -0.8972,  0.1325,  0.4380],
          [-0.0088,  0.0510,  0.2996,  ..., -0.0143,  0.1451, -0.6114]],

         [[-0.5258,  0.9914, -0.8065,  ...,  0.8578, -0.2347,  0.1629],
 

## Dataset Loading

In [9]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Dataset over the files stored directly inside ``root_dir``.

    Each item is the full path of one file; decoding the file is left to the
    caller.

    Args:
        root_dir: Directory whose entries make up the dataset.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # os.listdir takes only the path (passing `self` as well raised a
        # TypeError). Its order is arbitrary, so sort to make indices stable.
        self.list_files = sorted(os.listdir(root_dir))
        print(self.list_files)

    def __len__(self):
        """Number of files found in ``root_dir``."""
        return len(self.list_files)

    def __getitem__(self, indx):
        """Return the path of the ``indx``-th file.

        Raises:
            IndexError: If ``indx`` is out of range.
        """
        img_file = self.list_files[indx]
        return os.path.join(self.root_dir, img_file)

In [10]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

def get_files(dir):
    """Return the names of the entries found directly inside ``dir``.

    The names are returned in whatever order ``os.listdir`` yields them.
    """
    return os.listdir(dir)

def get_np_file(root_dir, file_name):
    """Load the NumPy file ``file_name`` located in ``root_dir`` and return its content."""
    return np.load(os.path.join(root_dir, file_name))

For training we need two kinds of data: the grayscale (L) images and the corresponding ab colour channels.

In [11]:
# Directories holding the ab (colour) shards and the lightness array.
ab_root_dir, l_root_dir = '/home/selvan/Documents/ab/ab', '/home/selvan/Documents/l'

# List each directory (ab first, then l) and show both listings.
ab_list_files, l_list_files = [get_files(folder) for folder in (ab_root_dir, l_root_dir)]
(ab_list_files, l_list_files)

(['ab3.npy', 'ab1.npy', 'ab2.npy'], ['gray_scale.npy'])

In [12]:
# Load the lightness array; the listing shown earlier has a single entry, 'gray_scale.npy'.
# NOTE: the variable name is misspelled ("sacale") but a later cell refers to it
# when building the dataset, so it is left unchanged here.
gray_sacale_img= get_np_file(l_root_dir,l_list_files[0])

In [13]:
# Load the three ab shards by explicit file name. Picking them by position in
# the directory listing is fragile: os.listdir returns entries in arbitrary
# order, so on another machine index 1 need not be 'ab1.npy', and the shards
# would then be concatenated out of step with the lightness images.
ab1_npy = get_np_file(ab_root_dir, 'ab1.npy')
ab2_npy = get_np_file(ab_root_dir, 'ab2.npy')
ab3_npy = get_np_file(ab_root_dir, 'ab3.npy')
ab1_npy.shape, ab2_npy.shape, ab3_npy.shape

((10000, 224, 224, 2), (10000, 224, 224, 2), (5000, 224, 224, 2))

### Now transformations

In [14]:
# Stack the three ab shards along the sample axis (axis 0), in shard order 1, 2, 3.
ab_train = np.concatenate((ab1_npy, ab2_npy, ab3_npy), axis=0) 

In [15]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class NumpyDataset(Dataset):
    """Pairs each lightness image with its ab colour map as normalised tensors.

    Every sample is cast to ``uint8``, scaled to ``[0, 1]``, resized
    bilinearly to ``image_size`` x ``image_size`` and normalised to
    ``[-1, 1]`` with mean 0.5 / std 0.5 per channel.

    The arrays are processed as plain tensors, not PIL images. A two-channel
    ``uint8`` array becomes a luminance+alpha ("LA") image in PIL, so a
    filtered resize treats the second channel as transparency instead of as
    an independent colour channel.

    Args:
        data1: Indexable collection of grayscale images, each ``(H, W)`` or
            ``(H, W, 1)``.
        data2: Indexable collection of ab images, each ``(H, W, C)``.
        gray_scale: Kept for backward compatibility. The normalisation is the
            same per channel for any channel count, so the flag no longer
            changes the result.
        image_size: Side length of the square output images.
    """

    def __init__(self, data1, data2, gray_scale=True, image_size=256):
        self.data1 = data1  # Grayscale data
        self.data2 = data2  # AB data
        self.gray_scale = gray_scale
        self.image_size = (image_size, image_size)

    def __len__(self):
        # Only indices present in both collections form a valid pair.
        return min(len(self.data1), len(self.data2))

    def _to_tensor(self, array):
        """Convert one ``(H, W)`` / ``(H, W, C)`` array to a ``(C, S, S)`` float tensor in [-1, 1]."""
        pixels = np.asarray(array).astype(np.uint8)
        if pixels.ndim == 2:
            pixels = pixels[:, :, None]  # add the channel axis
        # HWC uint8 -> CHW float in [0, 1]
        tensor = torch.from_numpy(pixels).permute(2, 0, 1).float().div(255.0)
        # interpolate() expects a batch axis; every channel is resized independently.
        tensor = torch.nn.functional.interpolate(
            tensor.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        return (tensor - 0.5) / 0.5

    def __getitem__(self, idx):
        """Return ``(lightness, ab)`` tensors for sample ``idx``."""
        img1 = self._to_tensor(self.data1[idx])
        img2 = self._to_tensor(self.data2[idx])
        return img1, img2

# Pair the lightness array with the stacked ab array
dataset = NumpyDataset(gray_sacale_img, ab_train, gray_scale=True)

# Serve the pairs as shuffled mini-batches of 16
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)



## Training the model

In [16]:
# Peek at the first batch only, to confirm the tensor shapes the DataLoader yields.
for batch, (X, y) in enumerate(dataloader):
    print(batch, X.shape,y.shape)
    break  # one batch is enough for a shape check

0 torch.Size([16, 1, 256, 256]) torch.Size([16, 2, 256, 256])


In [17]:
def generator_loss(gen_output, target, fake_output, lambda_L1=100):
    """Generator objective: adversarial term plus a weighted L1 term.

    Args:
        gen_output: Image produced by the generator.
        target: Ground-truth image, same shape as ``gen_output``.
        fake_output: Discriminator logits for the generated image.
        lambda_L1: Weight of the L1 term.

    Returns:
        Scalar tensor ``BCE(fake_output, 1) + lambda_L1 * L1(gen_output, target)``.
    """
    # The generator wants its output to be classified as real (label 1).
    real_labels = torch.ones_like(fake_output)
    adversarial_term = nn.functional.binary_cross_entropy_with_logits(fake_output, real_labels)
    reconstruction_term = nn.functional.l1_loss(gen_output, target)
    return adversarial_term + lambda_L1 * reconstruction_term

def discriminator_loss(real_output, fake_output):
    """Discriminator objective: mean of the BCE on real and on generated pairs.

    Args:
        real_output: Discriminator logits for real pairs (label 1).
        fake_output: Discriminator logits for generated pairs (label 0).

    Returns:
        Scalar tensor, the average of the two binary cross-entropy terms.
    """
    bce = nn.functional.binary_cross_entropy_with_logits
    loss_on_real = bce(real_output, torch.ones_like(real_output))
    loss_on_fake = bce(fake_output, torch.zeros_like(fake_output))
    return 0.5 * (loss_on_real + loss_on_fake)


In [19]:
from tqdm import tqdm
import torch.optim as optim

# Channel layout of the batches: X is the 1-channel lightness image and y the
# 2-channel ab target (see the batch shapes printed in the previous section).
L_CHANNELS = 1
AB_CHANNELS = 2

# --- Models -------------------------------------------------------------------
unet_generator = UnetGenerator(in_channels=L_CHANNELS)
# The generator must predict the 2-channel ab target, otherwise the L1 term in
# generator_loss compares tensors of different channel counts. Its output head
# is therefore replaced by one of the right width.
# NOTE(review): longer term, pass the output width to the model constructor.
generator_head = unet_generator.top_layer[0]
unet_generator.top_layer[0] = nn.ConvTranspose2d(
    generator_head.in_channels, AB_CHANNELS, kernel_size=4, stride=2, padding=1
)

discriminator = Discriminator()
# Discriminator.forward concatenates its two arguments itself; for an (L, ab)
# pair that is 1 + 2 = 3 channels, so the first convolution is rebuilt for it.
discriminator_stem = discriminator.initial[0]
discriminator.initial[0] = nn.Conv2d(
    L_CHANNELS + AB_CHANNELS, discriminator_stem.out_channels,
    kernel_size=4, stride=2, padding=1, padding_mode='reflect',
)

# Optimizers are created after the layer replacements so they track the new parameters.
g_optimizer = optim.Adam(unet_generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

epochs = 3

# --- Training -----------------------------------------------------------------
for epoch in tqdm(range(epochs)):
    for batch, (X, y) in enumerate(dataloader):
        fake_color = unet_generator(X)

        # Discriminator step: real pair (X, y) against generated pair
        # (X, fake_color). The inputs are passed separately because the
        # discriminator concatenates them; the fake is detached so this step
        # does not back-propagate into the generator.
        d_optimizer.zero_grad()
        real_output = discriminator(X, y)
        fake_output = discriminator(X, fake_color.detach())
        d_loss = discriminator_loss(real_output, fake_output)
        d_loss.backward()
        d_optimizer.step()

        # Generator step: score the same fake again, this time keeping the
        # graph so gradients reach the generator. Gradients that also land on
        # the discriminator are cleared by its zero_grad() on the next batch.
        g_optimizer.zero_grad()
        fake_output = discriminator(X, fake_color)
        g_loss = generator_loss(fake_color, y, fake_output)
        g_loss.backward()
        g_optimizer.step()

        # Print losses at regular intervals
        if batch % 1000 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Step [{batch}/{len(dataloader)}], "
                  f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")



  0%|          | 0/3 [00:00<?, ?it/s]

: 