In [27]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from celeb import CelebDatasetFast
from torchsummary import summary

In [28]:
# Preprocessing: uint8 PIL image -> tensor scaled to [0, 1] -> 178x178 (antialiased).
transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x / 255),
    transforms.Resize([178, 178], antialias=True),
])

dataset_size = 1000
batch_size = 5

# One CelebDatasetFast per split, each capped at `dataset_size` samples.
train_dataset, test_dataset, val_dataset = (
    CelebDatasetFast(split=split_name, transform=transform, total=dataset_size)
    for split_name in ('train', 'test', 'val')
)

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [29]:

# Peek at one training batch: print loader/batch sizes, then show input/target pairs.
print(len(train_loader))
# Each batch is ((mask, input), target); the mask is not displayed here.
(_mask, inp), target = next(iter(train_loader))
print(inp.shape)
print(target.shape)

n_rows = 3  # one input/target pair per row
for row in range(min(n_rows, inp.shape[0])):
    # Left column: model input (CHW -> HWC for imshow).
    plt.subplot(n_rows, 2, 2 * row + 1)
    plt.axis("off")
    if row == 0:
        plt.title("Input")
    plt.imshow(inp[row].permute((1, 2, 0)))

    # Right column: matching target.
    plt.subplot(n_rows, 2, 2 * row + 2)
    plt.axis("off")
    if row == 0:
        plt.title("Target")
    plt.imshow(target[row].permute((1, 2, 0)))

plt.subplots_adjust(left=0.05,
                    bottom=0.05,
                    right=0.9,
                    top=0.9,
                    wspace=0,
                    hspace=0.1)
plt.show()

200


RuntimeError: The size of tensor a (256) must match the size of tensor b (178) at non-singleton dimension 2

In [30]:

import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 with stride 1 keeps the spatial size unchanged; only the channel
    count changes (in_channels -> out_channels). The layers live in
    ``self.conv`` at indices 0..5 (conv, bn, relu, conv, bn, relu). Those
    indices appear in state_dict keys, so their order must not change or saved
    checkpoints will stop loading.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage reads in_channels, second stage reads out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                # bias=False: the following BatchNorm supplies the affine shift.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
class UNET(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final 1x1 convolution.
        features: channel widths of the encoder stages, shallow to deep. The
            bottleneck uses ``2 * features[-1]`` channels. Any iterable of ints
            is accepted; it is copied, so the caller's object is never shared.

    Raises:
        ValueError: if ``features`` is empty.

    Attribute names and ModuleList ordering determine the state_dict keys
    (e.g. ``downs.0.conv.0.weight``, ``ups.1.conv.3.weight``). Keep them
    stable so existing checkpoints continue to load.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default plus a local copy: a module-level mutable default
        # list would be shared by every instance and could be mutated by callers.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage maps the previous width to `feature` channels.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder, deepest stage first. `self.ups` alternates per stage:
        #   ups[2*i]   -> ConvTranspose2d upsampling (2*feature -> feature channels)
        #   ups[2*i+1] -> DoubleConv over [skip, upsampled] concatenation (2*feature -> feature)
        for feature in reversed(features):
            # output = s * (n-1) + k - 2*p; with k=2, s=2, p=0 this doubles H and W.
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, 1)

    def forward(self, x):
        """Map a batched (N, in_channels, H, W) tensor to (N, out_channels, H, W)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder consumes skips deepest-first.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes (e.g. 178 -> 89 -> 44), so the upsampled
            # map can be a pixel short of its skip; resize it to match exactly.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:], antialias=True)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [31]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [32]:
# RGB in, RGB out; weights are mapped onto `device` while loading.
model = UNET(in_channels=3, out_channels=3)
model.load_state_dict(torch.load("./models/20_loss.pth", map_location=device))

<All keys matched successfully>

In [33]:
model.train(False);  # same as model.eval(): BatchNorm uses running stats

In [34]:
import torchvision
def tensorToPIL(t):
    """Convert image tensor ``t`` to a PIL image in RGB mode via torchvision."""
    return torchvision.transforms.functional.to_pil_image(t, mode="RGB")


def save_tensor_as_image(t, name):
    """Save image tensor ``t`` to the file path ``name`` (format inferred by PIL)."""
    tensorToPIL(t).save(name)

In [None]:
# Run the model on one shuffled test batch and plot Input / Output / Ground Truth rows.
with torch.no_grad():
    h, w = 178, 178
    rows = 3
    model = model.to('cpu')

    # Separate name so the unshuffled `test_loader` from the data cell is not clobbered.
    shuffled_test_loader = DataLoader(test_dataset, batch_size, shuffle=True)
    # Each batch is ((mask, input), target).
    (masks, inputs), targets = next(iter(shuffled_test_loader))

    outputs = model(inputs)

    panels = (("Input", inputs), ("Output", outputs), ("Ground Truth", targets))
    cols = len(panels)
    for row in range(min(rows, inputs.shape[0])):
        for col, (title, batch) in enumerate(panels):
            # CHW -> HWC for imshow.
            img = batch[row].reshape(3, h, w).permute(1, 2, 0)
            if title == "Output":
                # Raw model output is unbounded; clamp to imshow's valid float range.
                img = torch.clamp(img, 0, 1)
            plt.subplot(rows, cols, row * cols + col + 1)
            plt.axis("off")
            if row == 0:
                plt.title(title)
            plt.imshow(img)

    # tight_layout recomputes spacing, so an explicit subplots_adjust would be overridden.
    plt.tight_layout()
    plt.savefig("output.png")
    plt.show()

In [6]:
from PIL import Image

In [7]:
from celeb import gen_line_mask
# Line mask from the celeb helper, moved channel-last -> channel-first and scaled by 1/255.
mask = torch.from_numpy(
    gen_line_mask((178, 178, 3), (8, 18))
).permute(2, 0, 1) / 255

In [8]:
def prepare_single_image(img, size=(178, 178)):
    """Load one image file and turn it into a batch of one for the model.

    Args:
        img: path (or file object) accepted by ``PIL.Image.open``.
        size: ``(h, w)`` to resize to; defaults to 178x178, the size used
            throughout this notebook.

    Returns:
        Float tensor of shape ``(1, 3, h, w)`` with values in [0, 1].
    """
    h, w = size
    transform = transforms.Compose([
        transforms.PILToTensor(),
        # NOTE(review): the dataset transform divides by 255 first and resizes with
        # antialias=True; this path resizes uint8 without antialiasing. Confirm the
        # mismatch with training preprocessing is intended.
        transforms.Resize((h, w), antialias=False),
        transforms.Lambda(lambda x: x / 255),
    ])

    # The context manager releases the file handle. convert("RGB") drops an alpha
    # channel / expands grayscale so the tensor always has the 3 channels UNET(3, 3) expects.
    with Image.open(img) as pil_img:
        img_tensor = transform(pil_img.convert("RGB"))

    return torch.unsqueeze(img_tensor, 0)

In [9]:
model = UNET(3,3)
path_to_model = "./models/20_loss.pth"
# map_location="cpu": inputs in the following cells are CPU tensors, and without it a
# checkpoint saved from CUDA fails to load on a CPU-only kernel.
model.load_state_dict(torch.load(path_to_model, map_location="cpu"))
model.eval();

In [10]:
# Single example as a batch of one; expected shape (1, 3, 178, 178).
input_img = prepare_single_image(img="./trainthick/traininput/6.png")
print(input_img.size())

torch.Size([1, 3, 178, 178])


In [11]:
# Inference only: no_grad avoids building an autograd graph (saves memory/time).
with torch.no_grad():
    output = model(input_img)

In [12]:
image_tensor = output[0].clamp(min=0.0, max=1.0)  # first (only) item, clipped to [0, 1]

In [15]:

# NOTE(review): "ouput.png" looks like a typo, but "output.png" is already written by the
# test-batch plotting cell; renaming it here would overwrite that figure.
for tensor_to_save, file_name in ((input_img[0], "input.png"), (image_tensor, "ouput.png")):
    save_tensor_as_image(tensor_to_save, file_name)


(eog:9331): EOG-CRITICAL **: 09:57:45.685: eog_image_get_file: assertion 'EOG_IS_IMAGE (img)' failed

(eog:9331): GLib-GIO-CRITICAL **: 09:57:45.685: g_file_equal: assertion 'G_IS_FILE (file1)' failed


In [26]:
import os
def gen_test_dataset(path_to_images, save_to_dir):
    """Write a line-masked copy of every image file in ``path_to_images``.

    Each file is loaded as RGB, resized to 178x178, scaled to [0, 1], combined
    with a newly generated line mask via element-wise ``torch.maximum``, and
    saved under the same file name in ``save_to_dir``. The directory is created
    if needed. Non-file entries (e.g. subdirectories) are skipped.
    """
    h, w = 178, 178
    # Built once; it does not depend on the individual image.
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize((h, w), antialias=False),
        transforms.Lambda(lambda x: x / 255),
    ])
    os.makedirs(save_to_dir, exist_ok=True)

    for image in sorted(os.listdir(path_to_images)):
        path = os.path.join(path_to_images, image)
        if not os.path.isfile(path):
            continue
        # A new mask per image (presumably random — see celeb.gen_line_mask); channel-last -> channel-first.
        mask = torch.from_numpy(gen_line_mask((h, w, 3), (7, 12))).permute((2, 0, 1)) / 255
        # convert("RGB") guarantees 3 channels so the shape matches the mask.
        with Image.open(path) as pil_img:
            img_tensor = transform(pil_img.convert("RGB"))
        masked_img_tensor = torch.maximum(mask, img_tensor)
        # os.path.join (not an f"./{dir}/..." string) keeps absolute save dirs absolute.
        save_tensor_as_image(masked_img_tensor, os.path.join(save_to_dir, image))


gen_test_dataset("./test_images", "./test_dataset")
    

In [20]:
import requests

import os

# Download 10 generated face images into ./test_images.
os.makedirs("./test_images", exist_ok=True)

for i in range(10):
    try:
        # A timeout keeps a stalled connection from hanging the kernel indefinitely.
        response = requests.get("https://thispersondoesnotexist.com/", timeout=30)
    except requests.RequestException as exc:
        print("Failed to download image:", exc)
        continue

    if response.status_code == 200:
        # Open a file in binary write mode and write the content of the response to the file
        with open(f"./test_images/image_{i}.png", 'wb') as file:
            file.write(response.content)
        print("Image downloaded successfully!")
    else:
        print("Failed to download image:", response.status_code)

Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!
Image downloaded successfully!


In [41]:
# Save a sample line mask and the same mask applied to ./6.png.
mask = torch.from_numpy(gen_line_mask((178, 178, 3), (7, 12))).permute((2,0,1))/255
# Reuse the single-image loader (same resize + scaling) instead of redefining the
# global `transform`; [0] drops the batch dimension.
img_tensor = prepare_single_image("./6.png")[0]
save_tensor_as_image(mask, "mask.png")
# Combine the mask with the image loaded in this cell, not with `image_tensor`
# (the clamped model output from an earlier cell).
save_tensor_as_image(torch.maximum(mask, img_tensor), "masked_image.png")


(eog:9331): EOG-CRITICAL **: 21:46:12.691: eog_window_ui_settings_changed_cb: assertion 'G_IS_ACTION (user_data)' failed

(eog:9331): EOG-CRITICAL **: 21:46:12.691: eog_window_ui_settings_changed_cb: assertion 'G_IS_ACTION (user_data)' failed

(eog:9331): EOG-CRITICAL **: 21:46:12.691: eog_window_ui_settings_changed_cb: assertion 'G_IS_ACTION (user_data)' failed

(eog:9331): EOG-CRITICAL **: 21:46:12.691: eog_window_ui_settings_changed_cb: assertion 'G_IS_ACTION (user_data)' failed
