In [None]:
""" Basic Python Library """
from PIL import Image
import numpy as np
import tqdm
import cv2
import os
from glob import glob
import matplotlib.pyplot as plt

"""Deep Learning Library"""
import torch
from torch import nn,optim
import torch.nn.functional as F
import torchvision as T
from torch.autograd import Variable
from torchsummary import summary

In [None]:
# Reproducibility: seed every RNG the notebook uses from a single constant.
seed=42
np.random.seed(seed)
torch.manual_seed(seed)
# is_available must be *called*: the bare function object is always truthy,
# which would select cuda:0 even on a CPU-only machine.
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #cuda:0 0 for indicating which gpu
print(device)

In [None]:
def read_resized(path,size=224):
    """Read a BGR image with OpenCV and resize it to ``size`` x ``size``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (cv2.imread returns
            None instead of raising, which would otherwise surface later as an
            opaque cv2.resize error).
    """
    image=cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    return cv2.resize(image,(size,size))

original=read_resized('original77.jpg')
dress=read_resized('dress77.jpg')
body=read_resized('body77.jpg')

In [None]:
%matplotlib inline
plt.subplot(3,1,1)
plt.title('Original')
plt.imshow(cv2.cvtColor(original,cv2.COLOR_BGR2RGB))
plt.subplot(3,1,2)
plt.title('Dress')
plt.imshow(cv2.cvtColor(dress,cv2.COLOR_BGR2RGB))
plt.subplot(3,1,3)
plt.title('Body')
plt.imshow(cv2.cvtColor(body,cv2.COLOR_BGR2RGB))

In [None]:
# Single-channel copies of both annotation images. The printed unique levels
# show which intensities the dress annotation actually uses.
dress_gray,body_gray=(cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) for img in (dress,body))
print(np.unique(dress_gray))

In [None]:
# Encode Dress and Body as binary masks: any non-white pixel (< 255) is part
# of the object and becomes 255, pure white becomes 0.
dress_gray,body_gray=[np.where(gray<255,255,0) for gray in (dress_gray,body_gray)]

In [None]:
# Show both binary masks, one per row.
mask_panels=(('Dress in gray format',dress_gray),('Body in gray format',body_gray))
for row,(caption,mask) in enumerate(mask_panels,start=1):
    plt.subplot(2,1,row)
    plt.title(caption)
    plt.imshow(mask)

In [None]:
# Skin = body pixels not covered by the dress (both masks are 0/255 here).
skin_gray=np.subtract(body_gray,dress_gray)
plt.imshow(skin_gray)

In [None]:
# Rescale the 0/255 masks to 0/1 with the object at 0 and everything else at 1.
# NOTE(review): this overwrites the masks in place, so re-running this cell
# without first re-running the cells above produces wrong values.
body_gray=(255-body_gray)/255
skin_gray=(255-skin_gray)/255
dress_gray=(255-dress_gray)/255
layers=(('Person/Background',body_gray),('Skin',skin_gray),('Dress',dress_gray))
for column,(caption,layer) in enumerate(layers,start=1):
    plt.subplot(1,3,column)
    plt.title(caption)
    plt.imshow(layer)

In [None]:
# Combination of segmenting image
# Combination of segmenting image: channel 0 <- skin, 1 <- dress, 2 <- background.
combine=np.zeros((224,224,3))
for channel,layer in enumerate((1-skin_gray,1-dress_gray,body_gray)):
    combine[:,:,channel]=layer
plt.imshow(combine)

### The method above is the simplest procedure for segmenting the objects in an image.

In [None]:
# Create custom dataset

class DressCollection(torch.utils.data.Dataset):
    """Paired dataset of photos and 3-channel segmentation targets.

    ``root`` must contain three sub-folders -- ``original``, ``dress`` and
    ``body`` -- holding, per sample, the photo, the dress annotation and the
    body annotation. Files are paired by sorted file name, so the folders must
    hold the same number of files whose names sort identically.

    Each item is ``(transform1(photo), transform2(target))``. ``target`` is a
    float32 ``(size, size, 3)`` array of 0/1 values with channels
    0 = skin (body and not dress), 1 = dress, 2 = background (not body).

    Args:
        root: dataset root directory.
        transform1: callable applied to the RGB PIL photo.
        transform2: callable applied to the numpy target (e.g. ToTensor()).
        size: side length the annotation masks are resized to (default 224).

    Raises:
        ValueError: if the three folders contain different numbers of files.
    """
    def __init__(self,root,transform1,transform2,size=224):
        # glob() order is arbitrary; sort so index i is the same sample in
        # all three folders.
        self.original_path=sorted(glob(os.path.join(root,'original','*')))
        self.body_path=sorted(glob(os.path.join(root,'body','*')))
        self.dress_path=sorted(glob(os.path.join(root,'dress','*')))
        # Bare file names, aligned with the *_path lists.
        self.original=[os.path.basename(p) for p in self.original_path]
        self.body=[os.path.basename(p) for p in self.body_path]
        self.dress=[os.path.basename(p) for p in self.dress_path]
        if not len(self.original_path)==len(self.body_path)==len(self.dress_path):
            raise ValueError(
                'original/body/dress folders differ in size: '
                f'{len(self.original_path)}/{len(self.body_path)}/{len(self.dress_path)}')
        self.transform1=transform1
        self.transform2=transform2
        self.size=size

    def __len__(self):
        return len(self.original_path)

    def __getitem__(self,idx):
        # convert('RGB') so grayscale/RGBA/palette files still yield 3 channels;
        # the context manager closes the underlying file handle.
        with Image.open(self.original_path[idx]) as photo:
            original_image=self.transform1(photo.convert('RGB'))
        dress_mask=self._read_mask(self.dress_path[idx],self.size)
        body_mask=self._read_mask(self.body_path[idx],self.size)
        gt=self._encode_target(dress_mask,body_mask)
        return original_image,self.transform2(gt)

    @staticmethod
    def _read_mask(path,size):
        """Read an annotation image, resize it, and mark non-white pixels True."""
        image=cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f'cv2.imread could not read {path!r}')
        image=cv2.resize(image,(size,size))
        gray=cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
        # Same rule as the exploration cells: anything darker than pure white
        # belongs to the annotated object.
        return gray<255

    @staticmethod
    def _encode_target(dress_mask,body_mask):
        """Build the (H, W, 3) float32 target from two same-shape 2-D masks.

        Non-zero entries mark the object in each mask. Returns channels
        0 = skin (body and not dress), 1 = dress, 2 = background (not body),
        each 0 or 1 -- valid targets for binary cross-entropy.
        """
        dress=np.asarray(dress_mask)>0
        body=np.asarray(body_mask)>0
        # Same channels as the exploration cells' (body - dress) arithmetic,
        # minus the -1 values that produced where the dress lies outside the body.
        return np.stack([body&~dress,dress,~body],axis=-1).astype(np.float32)

# Dataset root; it must contain original/, dress/ and body/ sub-folders.
root='Addres of the Dataset stored'
IMAGE_SIZE=224
# Resize to a fixed square (not just the shorter side) so every photo matches
# the square targets and images can be batched together.
transform1=T.transforms.Compose([T.transforms.ToTensor(),T.transforms.Resize((IMAGE_SIZE,IMAGE_SIZE))])
transform2=T.transforms.Compose([T.transforms.ToTensor()])

datasets=DressCollection(root,transform1,transform2)

# Per-channel mean over all training photos, accumulated one image at a time
# instead of stacking the whole dataset into memory.
channel_sum=torch.zeros(3)
pixel_count=0
for idx in range(len(datasets)):
    img,_=datasets[idx]
    channel_sum+=img.sum(dim=(1,2))
    pixel_count+=img[0].numel()
rgb_mean=channel_sum/max(pixel_count,1)

# Centre each channel on the dataset mean (std left at 1) to help optimisation.
transform1=T.transforms.Compose([T.transforms.ToTensor(),
                                 T.transforms.Normalize(mean=rgb_mean.tolist(),std=(1.0,1.0,1.0)),
                                 T.transforms.Resize((IMAGE_SIZE,IMAGE_SIZE))])

# Rebuild the dataset so it uses the normalising transform.
datasets=DressCollection(root,transform1,transform2)
BATCH_SIZE=32
dataloader=torch.utils.data.DataLoader(datasets,batch_size=BATCH_SIZE,shuffle=True)

## U-Net architecture: first developed for biomedical image segmentation, it is now used for general-purpose tasks, one of which is semantic segmentation of images.

### The links below show how to implement a U-Net model with the PyTorch deep-learning framework.
[Unet 1](https://www.youtube.com/watch?v=u1loyDCoGbE)
[Unet 2](https://www.youtube.com/watch?v=IHq1t7NxS8k)

In [None]:
""" This method is credited to Abhishek Thakur  """
# A predefined model could also be used instead.
def downsample(in_channels,out_channels):
    """Double 3x3 convolution block used on both encoder and decoder paths.

    Despite the name, this block keeps H and W unchanged (kernel 3, stride 1,
    padding 1). Spatial downsampling is done by max-pooling in Unet.forward.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    Returns:
        nn.Sequential of (Conv2d, ReLU, Dropout2d(0.4)) applied twice.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,3,1,1),
        nn.ReLU(),
        nn.Dropout2d(0.4),
        nn.Conv2d(out_channels,out_channels,3,1,1),
        nn.ReLU(),
        nn.Dropout2d(0.4),
    )


class UpConcat(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, skip concat, double conv.

    The layers are created once in __init__ so they are registered
    sub-modules of the parent network: they are trained by the optimizer,
    stored in the state_dict and moved by ``.to(device)``.

    Args:
        in_channels: channels of the deeper feature map being upsampled.
        out_channels: channels after upsampling and after the double conv.
        skip_channels: channels of the skip connection (default out_channels).
    """
    def __init__(self,in_channels,out_channels,skip_channels=None):
        super().__init__()
        if skip_channels is None:
            skip_channels=out_channels
        self.up=nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2,padding=0)
        self.conv=downsample(out_channels+skip_channels,out_channels)

    def forward(self,present_conv,previous_conv):
        value=torch.cat([self.up(present_conv),previous_conv],dim=1)
        return self.conv(value)


def concat_upsample(in_channels,out_channels,present_conv,previous_conv):
    """Deprecated one-shot form of UpConcat, kept for backward compatibility.

    It builds fresh, randomly initialised layers on every call, so their
    weights can never be learned; Unet uses registered UpConcat modules.
    """
    block=UpConcat(in_channels,out_channels,previous_conv.size(1)).to(present_conv.device)
    return block(present_conv,previous_conv)


class Unet(nn.Module):
    """Five-level U-Net producing a per-pixel sigmoid map.

    Encoder widths are 32-64-128-256-512 and the decoder mirrors them with
    skip connections. Input height and width must be divisible by 16 (four
    2x2 max-pools), e.g. 224x224.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output map (default 3).
    """
    def __init__(self,in_channels=3,out_channels=3):
        super(Unet,self).__init__()
        self.conv1=downsample(in_channels,32)
        self.conv2=downsample(32,64)
        self.conv3=downsample(64,128)
        self.conv4=downsample(128,256)
        self.conv5=downsample(256,512)

        # Decoder layers must live on the module to be trained and saved.
        self.up1=UpConcat(512,256)
        self.up2=UpConcat(256,128)
        self.up3=UpConcat(128,64)
        self.up4=UpConcat(64,32)

        self.endlayer=nn.Conv2d(32,out_channels,kernel_size=1,stride=1,padding=0)

    def forward(self,image):
        x1=self.conv1(image)
        x2=self.conv2(F.max_pool2d(x1,kernel_size=(2,2)))
        x3=self.conv3(F.max_pool2d(x2,kernel_size=(2,2)))
        x4=self.conv4(F.max_pool2d(x3,kernel_size=(2,2)))
        x5=self.conv5(F.max_pool2d(x4,kernel_size=(2,2)))
        x6=self.up1(x5,x4)
        x7=self.up2(x6,x3)
        x8=self.up3(x7,x2)
        x9=self.up4(x8,x1)
        return torch.sigmoid(self.endlayer(x9))

# Instantiate the network and place its parameters on the selected device.
model=Unet()
model=model.to(device)

In [None]:
# Focal-loss hyper-parameters, passed explicitly to focal_loss during training.
gamma=2.0
alpha=0.25
# Adam over every parameter of the U-Net.
optimizer=optim.Adam(params=model.parameters(),lr=1e-3)
def focal_loss(gamma,alpha,inputs,targets,logits=False):
    """Element-wise focal loss, averaged over all elements.

    Args:
        gamma: focusing exponent applied to (1 - pt).
        alpha: constant scale factor applied to every element (not the
            class-dependent alpha_t weighting).
        inputs: predictions -- probabilities, or raw logits if ``logits``.
        targets: tensor of the same shape as ``inputs`` with values in [0, 1].
        logits: treat ``inputs`` as logits and use the numerically stable
            BCE-with-logits.

    Returns:
        Scalar tensor: mean of alpha * (1 - pt)**gamma * BCE.
    """
    bce_fn=F.binary_cross_entropy_with_logits if logits else F.binary_cross_entropy
    per_element=bce_fn(inputs,targets,reduction='none')
    # BCE = -log(pt), so pt is recovered by exponentiating the negated loss.
    pt=per_element.neg().exp()
    modulating=(1-pt).pow(gamma)
    return (alpha*modulating*per_element).mean()
# torchsummary's summary() takes the device as the string 'cuda' or 'cpu'
# (it calls .lower() on it), so pass the device type, not the torch.device.
summary(model,input_size=(3,224,224),device=device.type)

In [None]:
# Training the model
loss_per_epoch=[]
EPOCHS=100
for e in tqdm.tqdm(range(EPOCHS)):
    running_loss=0.0
    model.train()
    for raw,gt in dataloader:
        raw=raw.to(device)
        # BCE requires the target dtype to match the float32 model output.
        gt=gt.to(device,dtype=torch.float32)
        optimizer.zero_grad()
        outs=model(raw)
        # focal_loss takes (gamma, alpha, inputs, targets).
        loss_fn=focal_loss(gamma,alpha,outs,gt)
        loss_fn.backward()
        optimizer.step()
        # Weight by batch size so the epoch value is a per-sample mean.
        running_loss+=loss_fn.item()*gt.size(0)
    loss_per_epoch.append(running_loss/len(dataloader.sampler))
    print('Epoch_',e,'_lOSS:',loss_per_epoch[e])
#%%
""" Save and load your model """
# Persist only the weights, then reload them into a fresh network.
model_path=os.path.join(root,'unet_model.pth')
torch.save(model.state_dict(),model_path)
model=Unet()
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
# Moving the model back to `device` keeps it on the same device as the
# inputs fed to it later.
model.load_state_dict(torch.load(model_path,map_location=device))
model=model.to(device)
model.eval()

## GrabCut algorithm
[grabcut cv2](https://www.pyimagesearch.com/2020/07/27/opencv-grabcut-foreground-segmentation-and-extraction/)

In [None]:
# Testing of the image
def grab_cut(image,size=224,iterations=5,margin_x=50,margin_y=10):
    """Rough foreground extraction with OpenCV GrabCut.

    The image is resized to ``size`` x ``size``. GrabCut is initialised with a
    rectangle inset by ``margin_x`` px on the left/right and ``margin_y`` px on
    the top/bottom, so the subject is expected to lie inside that rectangle.

    Args:
        image: colour image as loaded by cv2.imread (assumed 3-channel BGR
            uint8 -- TODO confirm for other inputs).
        size: side length of the square working image (default 224).
        iterations: number of GrabCut iterations (default 5).
        margin_x: horizontal inset of the initial rectangle (default 50).
        margin_y: vertical inset of the initial rectangle (default 10).

    Returns:
        (mask, final): ``mask`` is GrabCut's per-pixel label map. ``final``
        is the resized image with every background-labelled pixel painted
        white.
    """
    image=cv2.resize(image,(size,size))

    mask=np.zeros(image.shape[:2],np.uint8)
    # Scratch model buffers required by cv2.grabCut.
    bgdModel=np.zeros((1,65),np.float64)
    fgdModel=np.zeros((1,65),np.float64)
    height,width=image.shape[:2]

    rect=(margin_x,margin_y,width-2*margin_x,height-2*margin_y)
    cv2.grabCut(image,mask,rect,bgdModel,fgdModel,iterations,cv2.GC_INIT_WITH_RECT)
    # Definite (GC_BGD) and probable (GC_PR_BGD) background -> 0, rest -> 1.
    foreground=np.where((mask==cv2.GC_BGD)|(mask==cv2.GC_PR_BGD),0,1).astype('uint8')
    final=image*foreground[:,:,np.newaxis]
    final[foreground==0]=(255,255,255)

    return mask,final
#%%
### READ NEW IMAGE ###
TEST_IMAGE_PATH='Any test image format .jpg, .png'
plt.figure(figsize=(16,8))
image=cv2.imread(TEST_IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f'cv2.imread could not read {TEST_IMAGE_PATH!r}')
plt.subplot(1,3,1)
plt.title('Input')
# cv2.imread returns 3-channel BGR, so use BGR2RGB (BGRA2RGB needs 4 channels).
plt.imshow(cv2.cvtColor(cv2.resize(image,(224,224)),cv2.COLOR_BGR2RGB))

### GRABCUT + PREDICTION ###
mask_test,test=grab_cut(image)
# The network was trained on RGB photos (loaded with PIL); OpenCV gives BGR.
test=transform1(cv2.cvtColor(test,cv2.COLOR_BGR2RGB))
# Send the input to wherever the model's weights currently live.
test=test.to(next(model.parameters()).device)
model.eval()  # disable dropout for inference
with torch.no_grad():
    pred=model(test.unsqueeze(0))
# (1, C, H, W) -> (H, W, C) so matplotlib can draw the channels as RGB.
pred=pred.cpu().numpy().squeeze(0).transpose(1,2,0)
plt.subplot(1,3,2)
plt.title('Prediction')
plt.imshow(pred)