In [1]:
import os
import cv2
from os.path import join
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from PIL import Image
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from model import Unet


In [2]:
SHAPE_DOWN = (256, 256)
FLAG = cv2.IMREAD_GRAYSCALE
BATCH_SIZE = 8 

In [3]:
def read_image(path: str, shape: tuple) -> np.array:
    img = cv2.imread(path, FLAG)
    return cv2.resize(img, shape)

In [4]:
# img_dir = './data/Lung Segmentation/CXR_png'
# mask_dir = './data/Lung Segmentation/masks'
# mask_names = os.listdir (mask_dir)

# img_files = os.listdir (img_dir)
# mask_files = os.listdir (mask_dir)

# img_files_name = [i.split(".png")[0] for i in os.listdir (img_dir)]

# pairs = [(i.split("_mask")[0] + '.png', i) for i in mask_files if (i.split("_mask")[0] + '.png') in img_files]

# img_reshape = [[read_image(join(img_dir, i[0]), SHAPE_DOWN), 
#                 read_image(join(mask_dir, i[1]), SHAPE_DOWN)] for i in tqdm(pairs)]

# with open('data.pickle', 'wb') as f:
#     pickle.dump(img_reshape, f)

In [5]:
with open('data.pickle', 'rb') as f:
    img_reshape = pickle.load(f)

In [6]:
np_img = np.array(img_reshape)
np_img = np_img/255
train_ds = np.expand_dims(np_img, axis=-1)

In [7]:
train_size = int(train_ds.shape[0]*0.8)
train_ds, test_ds = train_ds[:train_size,:], train_ds[train_size:,:]
val_size = int(train_ds.shape[0]*0.2)
train_size = len(train_ds) - val_size
train_ds, val_ds = random_split(train_ds, [train_size, val_size])

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

model = Unet() 

model = model.to(device=device) 


learning_rate = 1e-4 
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
for epoch in range(5): 
    loss_ep = 0
    model.train()
    for batch_idx, data in enumerate(train_dl):
        data = data.to(device=device)
        
        optimizer.zero_grad()
        scores = model(data[:, 0].permute(0, 3, 1, 2).float())

        loss = criterion(scores.permute(0, 2, 3, 1), data[:, 1].float()) #permute(0, 3, 1, 2)
        
        loss.backward()
        optimizer.step()
        loss_ep += loss.item()
    print(f"Loss train in epoch {epoch} :::: {loss_ep/len(train_dl)}")
    
    model.eval()
    for batch_idx, data in enumerate(test_dl):
        data = data.to(device=device)
        scores = model(data[:, 0].permute(0, 3, 1, 2).float())
        loss = criterion(scores.permute(0, 2, 3, 1), data[:, 1].float()) #permute(0, 3, 1, 2)
        loss_ep += loss.item()
    print(f"Loss test in epoch {epoch} :::: {loss_ep/len(test_dl)}")
    

Loss train in epoch 0 :::: 0.050597319055510605
Loss test in epoch 0 :::: 0.2713375392059485
Loss train in epoch 1 :::: 0.022945955820867548
Loss test in epoch 1 :::: 0.08989862259477377
Loss train in epoch 2 :::: 0.020613926604552114
Loss test in epoch 2 :::: 0.08191868631790082
Loss train in epoch 3 :::: 0.019041431926028883
Loss test in epoch 3 :::: 0.07352655486514172
Loss train in epoch 4 :::: 0.018452227054892675
Loss test in epoch 4 :::: 0.07652916964143515


In [9]:
device = torch.device('cpu')

model = model.to(device=device)

In [10]:
torch.save(model, 'model.pt')

In [11]:
out = model(torch.from_numpy(np.expand_dims(np.expand_dims(np_img[0][0], axis=-1), axis=0)).permute(0, 3, 1, 2).float().to(device=device))

In [12]:
out = torch.round(out)

In [13]:
cv2.imshow('window_name', (out.permute(0, 2, 3, 1).reshape(256,256).detach().cpu().numpy()*255).astype(np.uint8))
cv2.waitKey(0)
cv2.destroyAllWindows()

In [14]:
out = model(torch.from_numpy(np.expand_dims(train_ds.dataset[0][0], axis=0)).permute(0, 3, 1, 2).float().to(device=device))

In [15]:
cv2.imshow('window_name', (out.permute(0, 2, 3, 1).reshape(256,256).detach().cpu().numpy()*255).astype(np.uint8))
cv2.waitKey(0)
cv2.destroyAllWindows()

In [16]:
cv2.imshow('window_name', train_ds.dataset[0][1])
cv2.waitKey(0)
cv2.destroyAllWindows()