In [None]:
from model.unet import Unet
import torch
import torch.nn as nn

In [None]:
from unet_datasets.dataset import Dataset, Normalization, RandomFlip, ToTensor
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
from IPython.display import display

# Training pipeline. RandomFlip is assumed to be a stochastic augmentation
# (its implementation lives in unet_datasets.dataset and is not visible here).
transformss = transforms.Compose([Normalization(mean=0.5, std=0.5), RandomFlip(), ToTensor()])

# Validation pipeline: identical preprocessing but WITHOUT the random flip, so the
# validation set is evaluated on un-augmented inputs every epoch.
transforms_val = transforms.Compose([Normalization(mean=0.5, std=0.5), ToTensor()])

data_set = Dataset("./unet_datasets/train", transform=transformss)
data_set_val = Dataset("./unet_datasets/val", transform=transforms_val)

# Preview the first training input ...
timg = Image.fromarray(np.load('./unet_datasets/train/input_000.npy'))
display(timg)

# ... and its label. NOTE: despite the variable name, this file is read from the
# *train* directory (label of sample 000), not from the validation set.
valimg = Image.fromarray(np.load('./unet_datasets/train/label_000.npy'))
display(valimg)

In [None]:
# Training batches are reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(data_set, batch_size=4, shuffle=True)
# Validation is iterated in a fixed order: shuffling brings no benefit for
# evaluation, and a stable order keeps the last validation batch (the one the
# training cell visualises) the same from epoch to epoch.
dataloader_val = torch.utils.data.DataLoader(data_set_val, batch_size=2, shuffle=False)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
bs = 4  # nominal batch size (the DataLoader cell hard-codes its own batch sizes)
lr = 0.0005
epochs = 100

unet = Unet(dim=64, mults=[1, 2, 4, 8], channel_scale=1).to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
# BCEWithLogitsLoss applies the sigmoid internally, so the network output is
# treated as raw logits.
criterion = nn.BCEWithLogitsLoss().to(device)

# Per-epoch training history, filled in by the training cell.
# 'valid_images' is not written by any cell visible in this notebook.
trainer = {
    'train_loss': [],
    'valid_loss': [],
    'valid_images': [],
}

In [None]:
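# Disabled sanity check: push one training batch through the network and print
# the input/output tensor shapes. Uncomment to run.
# time = torch.tensor(1, device=device).unsqueeze(0).repeat(bs)
# data = next(iter(dataloader))
# inputs = data['input'].to(device)

# print(inputs.shape, time.shape)

# out = unet(inputs, time)
# print(out.shape)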

In [None]:
def visualize_output(out):
    """Display the first channel of the first sample of ``out`` as a grayscale image.

    The first sample is min-max rescaled to the 0..255 range (the min and max are
    taken over all of that sample's channels), truncated to ``uint8`` and shown
    with ``display``.

    Args:
        out: batched tensor; ``out[0][0]`` must be a 2-D array that
            ``Image.fromarray`` can render.

    A sample whose values are all equal (or whose range is NaN) has no usable
    dynamic range; it is rendered as an all-black image rather than dividing by
    zero.
    """
    array = out[0].detach().cpu().numpy()
    lo, hi = array.min(), array.max()
    span = hi - lo
    if span > 0:
        scaled = (array - lo) / span * 255
    else:
        # Constant (or NaN) output: avoid 0/0 -> NaN -> undefined uint8 cast.
        scaled = np.zeros_like(array)
    image = Image.fromarray(scaled.astype(np.uint8)[0])
    display(image)

In [None]:
from tqdm import tqdm

torch.cuda.empty_cache()
for epoch in range(epochs):
    # ---------------------------- training ----------------------------
    unet.train()
    epoch_loss = 0
    for data in tqdm(dataloader):
        inputs = data['input'].to(device)
        labels = data['label'].to(device)
        # Second argument to `unet`: a constant 1 per sample (its meaning is
        # defined by model.unet). It is sized from the actual batch so that a
        # final batch smaller than the nominal batch size does not mismatch.
        time = torch.ones(inputs.shape[0], dtype=torch.long, device=device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        output = unet(inputs, time)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # The recorded value is the SUM of the per-batch losses for this epoch.
    print(f"Epoch {epoch} - Train Loss : {epoch_loss}")
    trainer['train_loss'].append(epoch_loss)

    # ---------------------------- validation ----------------------------
    unet.eval()
    valid_loss = 0
    output = None  # stays None if the validation loader yields no batches
    with torch.no_grad():
        for data in tqdm(dataloader_val):
            inputs = data['input'].to(device)
            labels = data['label'].to(device)
            time = torch.ones(inputs.shape[0], dtype=torch.long, device=device)

            output = unet(inputs, time)
            valid_loss += criterion(output, labels).item()

    print(f"Epoch {epoch} - Valid Loss : {valid_loss}")
    # Show the prediction for the first sample of the last validation batch.
    if output is not None:
        visualize_output(output)
    trainer['valid_loss'].append(valid_loss)
    torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt

# One figure per recorded curve. Each point is the per-epoch loss stored in
# `trainer` by the training cell (summed over that epoch's batches).
for key, title in (('train_loss', 'Training loss'), ('valid_loss', 'Validation loss')):
    fig, ax = plt.subplots()
    ax.plot(trainer[key])
    ax.set(title=title, xlabel='epoch', ylabel='loss (sum over batches)')
    plt.show()
    plt.close(fig)