In [16]:
import torch.nn as nn
import torch
import torchvision.transforms.functional as TF

In [17]:
class DoubleConv(nn.Module):
  def __init__(self,in_channels,out_channels):
    super(DoubleConv,self).__init__()
    self.conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,out_channels,3,1,1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            
    )
  def forward(self,X):
    return self.conv(X)

In [18]:
class UNET(nn.Module):
  def __init__(self,in_c,o_c=1,features=[64,128,256,512]):
    super(UNET,self).__init__()
    self.downs=nn.ModuleList(



    )
    self.ups=nn.ModuleList(
                                


    )
    self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
    for f in features:
      self.downs.append(DoubleConv(in_c,f))
      in_c=f
    for f in reversed(features):
      self.ups.append(
                        nn.ConvTranspose2d(f*2,f,kernel_size=2,stride=2)
      )
      self.ups.append(DoubleConv(f*2,f))
      in_c=f
    self.bottleneck=DoubleConv(features[-1],features[-1]*2)
    self.final_conv=nn.Conv2d(features[0],o_c,kernel_size=1)
  def forward(self,X):
    
    skip_connections=[]
    for down in self.downs:
      X=down(X)
      skip_connections.append(X)
      X=self.pool(X)
    X=self.bottleneck(X)
    skip_connections=skip_connections[::-1]
    for idx in range(0,len(self.ups),2):
      X=self.ups[idx](X)
      skip_connection=skip_connections[idx//2]
      if X.shape!=skip_connection.shape:
        X=TF.resize(X,size=skip_connection.shape[2:])
      c=torch.cat((skip_connection,X),dim=1)
      X=self.ups[idx+1](c)
    return self.final_conv(X)
#Choose input divisble by 16
#


In [19]:
def test():
  x=torch.randn((3,3,160,240))
  model=UNET(3,1)
  preds=model(x)
  print(preds.shape)
  print(x.shape)

In [20]:
test()

torch.Size([3, 1, 160, 240])
torch.Size([3, 3, 160, 240])


In [21]:
!curl -o data.zip 'https://storage.googleapis.com/kaggle-data-sets/1152755/1932509/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20210509%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20210509T091142Z&X-Goog-Expires=259199&X-Goog-SignedHeaders=host&X-Goog-Signature=41d6db3301e8cdce539919a8b3095746166c99b26b91ba3c65823f8675d3fec4a454b763c08821faebb79d1f7e48878c2b3bd698a635641cf350c876251a3a9db8bbe21eb215693b6a2cceebdc0a2255db89b21b99accd678bbed6f555a52d2d422262a044bf0ae74d16daeb6420ce185f21fad7a506c35c1761937c9c4faa362e916f0012926d5671dbab114f1c25873901bd13466e4596ae3d869dae4c77904f335fa08ab94252488e82ae809d732b9746456b9c6b710dc289f5107b831058d1bc3be6818d7fa2633a3216c05672ad7f30ed1da02dfff7f9383dde9425e634e938ad23a1e07fdb4798ad56fd25dd5c2194dc1fb45d50554b18e4042fc63606' -H 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:83.0) Gecko/20100101 Firefox/83.0' -H 'Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8' -H 'Accept-Language: en-US,en;q=0.5' --compressed -H 'Referer: https://www.kaggle.com/' -H 'DNT: 1' -H 'Connection: keep-alive' -H 'Upgrade-Insecure-Requests: 1' -H 'Sec-GPC: 1'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1860M  100 1860M    0     0  71.0M      0  0:00:26  0:00:26 --:--:-- 80.5M


In [22]:
!ls

data.zip  sample_data


In [23]:
!unzip data.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: segmentation_full_body_tik_tok_2615_img/images/137_00390.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/137_00420.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/139_00120.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/139_00210.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/139_00240.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00090.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00120.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00150.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00180.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00210.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00240.png  
  inflating: segmentation_full_body_tik_tok_2615_img/images/141_00270.png  
  inflating: segmentati

In [24]:
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

In [25]:
class TikTok(Dataset):
  def __init__(self,img_dir,mask_dir,transform=None):
    self.img_dir=img_dir
    self.mask_dir=mask_dir
    self.transform=transform
    self.images=os.listdir(self.img_dir)
  def __len__(self):
    return len(self.images)
  def __getitem__(self,i):
    img_path=os.path.join(self.img_dir,self.images[i])
    mask_path=os.path.join(self.mask_dir,self.images[i])

    image=np.array(Image.open(img_path).convert("RGB"))
    mask=np.array(Image.open(mask_path).convert("L"),dtype=np.float32)
    mask[mask==255.0]=1
    if self.transform is not None:
      aug=self.transform(image=image,mask=mask)
      image=aug['image']
      mask=aug['mask']
    return image,mask

In [26]:
!pip install albumentations==0.4.6
import albumentations 
from albumentations.pytorch import ToTensorV2




In [27]:
#Train
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim


import numpy as np


In [28]:
LR=1e-4
device="cuda" if torch.cuda.is_available() else "cpu"
batch_size=32
max_epochs=10
num_workers=2
image_height=160
image_width=240
PIN_Memory=True
Load_Model=False
Train_Img_dir='/content/segmentation_full_body_tik_tok_2615_img/images/'
Train_Mask_dir='/content/segmentation_full_body_tik_tok_2615_img/masks'

In [52]:
import torch
import torchvision

def save_checkpoints(state,filename="my_c.pth.tar"):
  print("saving...")
  torch.save(state,filename)

def load_checkpoint(c,model):
  print("loading_model...")
  model.load_state_dict(c["state_dict"])

def Loader(Img_dir,Mask_dir,train_transform,pin_memory):
  train_ds=TikTok(Img_dir,Mask_dir,transform=train_transform
                    )
  return DataLoader(train_ds,
                    batch_size=batch_size,
                  shuffle=True,
                  pin_memory=pin_memory)
def check_accuracy(x,y,model,device='cpu'):
  num_correct=0
  num_pixels=0
  model.eval()
  dice_sc=0
  with torch.no_grad():
      x=x.to(device)
      y=y.to(device).unsqueeze(1)
      preds=torch.sigmoid(model(x))
      preds=(preds>0.5).float()
      num_correct+=(preds==y).sum()
      num_pixels+=torch.numel(preds)
      dice_sc+=(2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )
  print(f"Got is {num_correct}/{num_pixels} with acc {(num_correct/num_pixels)*100} and dice score = {dice_sc}")
  model.train()
  torchvision.utils.save_image(preds,"/content/saved/pred_{}.png".format('r'))
  torchvision.utils.save_image(y,"/content/saved/{}.png".format('r'))


In [65]:
 val_transform=A.Compose(
      [
       
          A.Resize(height=image_height,width=image_width),
          A.Normalize(
              mean=[0.0,0.0,0.0],
              std=[1.0,1.0,1.0],
              max_pixel_value=255.0



          ),
       

          ToTensorV2(),



      ])
tl=Loader(Train_Img_dir,Train_Mask_dir,val_transform,PIN_Memory)
X_test,Y_test=next(iter(tl))



In [40]:
def train_fn(loader,model,optimzier,loss_fn,scaler):
  loop=tqdm(loader)
  for i,(data,targets) in enumerate(loop):
    data=data.to(device)
    targets=targets.float().unsqueeze(1).to(device)

    #Forward
    with torch.cuda.amp.autocast():
      predictions=model(data)
      loss=loss_fn(predictions,targets)
    optimzier.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimzier)
    scaler.update()
  

In [41]:
train_transform=A.Compose(
      [
       
          A.Resize(height=image_height,width=image_width),
          A.Rotate(limit=35,p=1.0),
          A.HorizontalFlip(p=0.2),
          A.VerticalFlip(p=0.1),
          A.Normalize(
              mean=[0.0,0.0,0.0],
              std=[1.0,1.0,1.0],
              max_pixel_value=255.0


          ),
          ToTensorV2(),



      ])


val_transform=A.Compose(
      [
       
          A.Resize(height=image_height,width=image_width),
          A.Normalize(
              mean=[0.0,0.0,0.0],
              std=[1.0,1.0,1.0],
              max_pixel_value=255.0


          ),
          ToTensorV2(),



      ])
tl=Loader(Train_Img_dir,Train_Mask_dir,train_transform,PIN_Memory)
model=UNET(in_c=3,o_c=1).to(device)
    
loss_fn=nn.BCEWithLogitsLoss()
optimizer=optim.Adam(model.parameters(),lr=LR)
scaler=torch.cuda.amp.GradScaler()
  

    #save model

  

In [42]:
!mkdir saved/


mkdir: cannot create directory ‘saved/’: File exists


In [43]:
device

'cuda'

In [44]:
X_test.shape

torch.Size([32, 3, 160, 240])

In [59]:
max_epochs=10

In [66]:
for epochs in range(max_epochs):
    train_fn(tl,model,optimizer,loss_fn,scaler)
    checkpoint={
        "state_dict":model.state_dict(),
        "optimizer":optimizer.state_dict(),
    }
    save_checkpoints(checkpoint)
    check_accuracy(X_test,Y_test,model,device)
    save_checkpoints(model,'/content/drive/MyDrive/Models/CV/Segmentation/model_unetv{}'.format(epochs))
    torch.save(model,'/content/drive/MyDrive/Models/CV/Segmentation/model_full_unetv{}'.format(epochs))


100%|██████████| 82/82 [01:47<00:00,  1.31s/it]


saving...
Got is 1214213/1228800 with acc 98.81290435791016 and dice score = 0.9384673237800598
saving...


100%|██████████| 82/82 [01:49<00:00,  1.34s/it]


saving...
Got is 1215959/1228800 with acc 98.95500183105469 and dice score = 0.9455527663230896
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1217564/1228800 with acc 99.08561706542969 and dice score = 0.9520652294158936
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1218429/1228800 with acc 99.15601348876953 and dice score = 0.9561496376991272
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1218001/1228800 with acc 99.12117767333984 and dice score = 0.9545215368270874
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1219249/1228800 with acc 99.22274017333984 and dice score = 0.9592949151992798
saving...


100%|██████████| 82/82 [01:47<00:00,  1.32s/it]


saving...
Got is 1216606/1228800 with acc 99.00765228271484 and dice score = 0.9494041800498962
saving...


100%|██████████| 82/82 [01:47<00:00,  1.32s/it]


saving...
Got is 1216261/1228800 with acc 98.97957611083984 and dice score = 0.9457399845123291
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1219845/1228800 with acc 99.271240234375 and dice score = 0.9623174071311951
saving...


100%|██████████| 82/82 [01:48<00:00,  1.32s/it]


saving...
Got is 1220801/1228800 with acc 99.34903717041016 and dice score = 0.9658997654914856
saving...


In [63]:
  check_accuracy(X_test,Y_test,model,device)


Got is 1209348/1228800 with acc 98.4169921875 and dice score = 0.9181244373321533


In [67]:
save_checkpoints(model,'/content/drive/MyDrive/Models/CV/Segmentation/model_unetv3')
torch.save(model,'/content/drive/MyDrive/Models/CV/Segmentation/model_full_unetv3')


saving...
