In [1]:
import os
import sys
import random
import numpy as np
import torch
from torch import nn
from torch import optim
from PIL import Image
from tqdm import tqdm
from tensorboardX import SummaryWriter
import albumentations as A
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sys.path.append('./modules')

from UNet import UNet
from Dataset import Dataset
from ImageLoader import ImageLoader

In [3]:
# Run settings. safe_load is enough for a plain-data config and, unlike yaml.Loader,
# cannot construct arbitrary Python objects; the `with` block also closes the file
# handle that a bare open() would leave dangling.
with open('./settings.yaml', 'r') as settings_file:
    data = yaml.safe_load(settings_file)

images_path = data['images_path']
masks_path = data['masks_path']
image_patches_path = data['image_patches_path']
mask_patches_path = data['mask_patches_path']

patch_size = data['patch_size']
sigma = data['sigma']
num_neg_samples = data['num_neg_samples']

# Seed every RNG the run touches: torch (weight init, DataLoader shuffling) and
# random / numpy (presumably used by albumentations for the random augmentations --
# depends on the installed albumentations version, TODO confirm). An optional
# `seed` key in settings.yaml overrides the default.
SEED = data.get('seed', 0)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Geometric augmentations: a random multiple-of-90-degree rotation (always applied),
# plus random transpose and horizontal/vertical flips, i.e. dihedral symmetries.
transform = A.Compose([
    A.RandomRotate90(p=1),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
])

# Source frames: the filenames differ only in the trailing frame index.
FRAME_TEMPLATE = "Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted{}.png"
TRAIN_FRAMES = [10, 20, 30, 40]
VAL_FRAMES = [50]

train_set = [FRAME_TEMPLATE.format(frame) for frame in TRAIN_FRAMES]
val_set = [FRAME_TEMPLATE.format(frame) for frame in VAL_FRAMES]

In [4]:
def load_patch_pairs(image_names):
    """Load every (image, mask) patch pair saved for the given source images.

    For each name in `image_names`, every ``*.npy`` file in
    ``<image_patches_path>/<name>/`` is loaded together with the file of the same
    name in ``<mask_patches_path>/<name>/``.

    Args:
        image_names: source-image filenames (the per-image patch sub-directories).

    Returns:
        (images, masks): the patches stacked along a new leading axis, both cast to
        their common NumPy dtype (the same promotion np.array([image, mask]) applies).

    Raises:
        ValueError: if no patches are found, or (from np.stack) if patch shapes differ.
    """
    images, masks = [], []
    for name in image_names:
        # os.listdir order is arbitrary; sort so the dataset order is reproducible.
        patch_files = sorted(
            f for f in os.listdir(os.path.join(image_patches_path, name)) if f.endswith('.npy')
        )
        for patch_file in tqdm(patch_files, desc=name):
            images.append(np.load(os.path.join(image_patches_path, name, patch_file)))
            masks.append(np.load(os.path.join(mask_patches_path, name, patch_file)))
    if not images:
        raise ValueError("no .npy patches found for {}".format(image_names))
    images, masks = np.stack(images), np.stack(masks)
    common_dtype = np.result_type(images, masks)
    return images.astype(common_dtype, copy=False), masks.astype(common_dtype, copy=False)


train_images, train_masks = load_patch_pairs(train_set)
val_images, val_masks = load_patch_pairs(val_set)

train_ds = Dataset(train_images, train_masks)
val_ds = Dataset(val_images, val_masks)

train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=10)
val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, batch_size=10)

epochs = 2000
lr = 0.5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

unet = UNet(n_channels=1, n_classes=1).to(device)
lossFunc = nn.MSELoss()
opt = optim.SGD(unet.parameters(), lr=lr)
# T_max is counted in scheduler.step() calls; the training loop steps once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100, eta_min=0)

# writer = SummaryWriter('./2023.05.14 with_neg/runs/model_psize={:03}, dihedral_4'.format(patch_size))

Image: Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted10.png


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 817/817 [00:03<00:00, 230.20it/s]


Image: Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted20.png


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 836/836 [00:03<00:00, 240.38it/s]


Image: Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted30.png


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 823/823 [00:03<00:00, 249.75it/s]


Image: Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted40.png


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 825/825 [00:03<00:00, 244.42it/s]


Image: Bubbles_movie_01_x1987x2020x81_3cv2_NLM_template20_search62_inverted50.png


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 859/859 [00:03<00:00, 242.20it/s]


In [5]:
def prepare_batch(x, y, augment):
    """Turn a DataLoader batch into float tensors on `device` with channels at dim 1.

    Args:
        x, y: image / mask batches as yielded by the DataLoader. The
            ``np.moveaxis(..., -1, 1)`` below assumes they are channel-last,
            presumably (batch, H, W, 1) -- TODO confirm against Dataset.
        augment: when True, apply the random `transform` to each (image, mask)
            pair; albumentations applies the same geometric op to both.

    Returns:
        (x, y) as torch.float tensors on `device`.
    """
    # Convert once per batch; converting inside the per-sample loop re-copied the
    # whole batch for every sample.
    x, y = np.array(x), np.array(y)
    if augment:
        for sample in range(x.shape[0]):
            augmented = transform(image=x[sample], mask=y[sample])
            x[sample] = augmented['image']
            y[sample] = augmented['mask']
    x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device, dtype=torch.float)
    y = torch.Tensor(np.moveaxis(y, -1, 1)).to(device, dtype=torch.float)
    return x, y


progress = tqdm(range(epochs))
for epoch in progress:
    unet.train()
    total_train_loss = 0.0
    total_val_loss = 0.0

    for x, y in train_loader:
        x, y = prepare_batch(x, y, augment=True)

        pred = unet(x)
        loss = lossFunc(pred, y)
        # .item() yields a plain float; summing the loss tensor itself would keep an
        # autograd-tracked tensor (and its graph nodes) growing for the whole epoch.
        total_train_loss += loss.item()

        opt.zero_grad()
        loss.backward()
        opt.step()

    unet.eval()
    with torch.no_grad():
        for x, y in val_loader:
            # No augmentation on validation: keeps the val loss deterministic and
            # comparable from epoch to epoch.
            x, y = prepare_batch(x, y, augment=False)
            pred = unet(x)
            total_val_loss += lossFunc(pred, y).item()

    # Step the cosine schedule once per epoch so T_max=100 spans 100 epochs
    # (stepping per batch made the LR cycle several times within a single epoch).
    scheduler.step()

    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    progress.set_postfix(train_loss=avg_train_loss, val_loss=avg_val_loss)

#     writer.add_scalar('train_loss', avg_train_loss, epoch)
#     writer.add_scalar('val_loss', avg_val_loss, epoch)
    
#     if (epoch + 1) % 100 == 0:
#         model_param_path = './2023.05.14 with_neg/model_saves/model_psize={:03}, dihedral_4, epoch={:04}.pth'.format(patch_size, epoch + 1)
#         torch.save(unet.state_dict(), model_param_path)

# writer.flush()
# writer.close()

  0%|▏                                                                                                                    | 3/2000 [00:59<10:56:33, 19.73s/it]


KeyboardInterrupt: 