In [None]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import os
import time


# Names of the architectures to train and evaluate; the training and
# evaluation loops pass each entry to get_model().
model_name_list =  ["fpunet"]


In [None]:
# Side length (in pixels) of the square maps fed to the network.
fig_size = 128
# Number of samples per mini-batch for every DataLoader.
batch_size = 64
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"



from utils.FPUNet import FPUNet


def get_model(model_name):
    """Build the network registered under ``model_name`` and move it to ``device``.

    Args:
        model_name: architecture identifier; currently only ``"fpunet"`` is known.

    Returns:
        The freshly constructed model, already placed on ``device``.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    feature_scale = 2

    if model_name == "fpunet":
        return FPUNet(in_channels=3, n_classes=3, feature_scale=feature_scale).to(device)

    # Fail loudly instead of hitting an UnboundLocalError on `return model`.
    raise ValueError(f"Unknown model name: {model_name!r}")


# net = get_model("fpunet")

# from torchinfo import summary  
# print(summary(net, (1,3,fig_size,fig_size)))

# test_input = torch.tensor(np.random.rand(1,3,fig_size,fig_size),dtype=torch.float32).to(device)
# with torch.no_grad():
#     net.eval()
#     out = net(test_input)
#     print(out.shape)
#     print(out[0,:,0,0])

In [None]:
# Directory (relative to the working directory) that holds the three .npz splits.
file_name = "dataset/gen_data"

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code from a crafted file. Only load archives you trust;
# if the arrays are plain numeric arrays this flag is presumably unnecessary — confirm.
train_data = np.load(f'./{file_name}/train_data.npz', allow_pickle=True)
val_data = np.load(f'./{file_name}/val_data.npz', allow_pickle=True)
test_data = np.load(f'./{file_name}/test_data.npz', allow_pickle=True)

# Every split stores its inputs under "train_list" and its targets under
# "label_list", regardless of whether it is the train, val or test archive.
input_train = train_data["train_list"]
out_train = train_data["label_list"]

input_val = val_data["train_list"]
out_val = val_data["label_list"]

input_test = test_data["train_list"]
out_test = test_data["label_list"]

# Expand the data from 1 channel to 3 channels.
# The channels represent, in order: free, occupied, unknown.
def modify_data(data):
    """Turn single-channel grayscale maps into three binary class channels.

    Args:
        data: tensor of shape (N, 1, H, W) holding grayscale pixel values.

    Returns:
        Tensor of shape (N, 3, H, W) in the default float dtype where
        channel 0 (free) is 1 for values > 200, channel 1 (occupied) is 1 for
        values < 10, and channel 2 (unknown) is 1 for values strictly between
        90 and 110. Pixels outside all three ranges are 0 in every channel.

    Raises:
        ValueError: if ``data`` does not have exactly one channel.
    """
    if data.shape[1] != 1:
        raise ValueError(f"expected a single input channel, got {data.shape[1]}")

    res = torch.zeros((data.shape[0], 3, data.shape[2], data.shape[3]))
    # Drop the channel axis once and threshold the whole batch at a time
    # instead of looping over samples in Python.
    gray = data[:, 0]
    res[:, 0] = gray > 200
    res[:, 1] = gray < 10
    res[:, 2] = (gray > 90) & (gray < 110)
    return res

# Reshape every raw split to (N, 1, fig_size, fig_size), wrap it as a tensor and
# expand it to the 3-channel free/occupied/unknown encoding. The splits are
# converted in the same order as before: train, val, test (input then target).
(input_train, out_train,
 input_val, out_val,
 input_test, out_test) = (
    modify_data(torch.from_numpy(raw_split.reshape((-1, 1, fig_size, fig_size))))
    for raw_split in (input_train, out_train, input_val, out_val, input_test, out_test)
)



In [None]:
def trans_img(img):
    """Reorder a channel-first image (C, H, W) to channel-last (H, W, C) for plotting."""
    channel_last_axes = (1, 2, 0)
    return np.transpose(img, channel_last_axes)

# Preview training sample 10: the input on top, its target underneath.
for panel, sample in enumerate((input_train[10], out_train[10]), start=1):
    plt.subplot(2, 1, panel)
    plt.imshow(trans_img(sample))
plt.show()

In [None]:
import torch.utils.data as data_utils

# Training split: the noisy observations are the inputs and the labels are the
# prediction targets. Shuffled so every epoch sees a different batch order.
train_dataset = data_utils.TensorDataset(input_train, out_train)
train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: kept in file order so sample indices stay stable for plotting.
test_dataset = data_utils.TensorDataset(input_test, out_test)
test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Validation split: not shuffled. The validation loss is averaged over batches,
# so a reshuffled, ragged final batch would make it vary between epochs for
# identical weights and add noise to best-checkpoint selection.
val_dataset = data_utils.TensorDataset(input_val, out_val)
val_loader = data_utils.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

def image_norm(x):
    """Linearly map pixel values from [0, 255] to [-1, 1]."""
    scaled = x * 2.0 / 255
    return scaled - 1.0

In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Loss weight of the "occupied" class; free and unknown each get 0.2 below.
obs_weight = 0.6
class_weights = torch.tensor([0.2,obs_weight,0.2])  # free occupied unknown
# Number of training epochs; also embedded in the checkpoint file names.
num_epochs = 400

# Train every model in model_name_list, keep the checkpoint with the lowest
# validation loss, and plot the train/validation loss curves.
for now_model in model_name_list:
    # Release cached GPU memory before building the next model.
    torch.cuda.empty_cache()
    net = get_model(now_model)

    # Weighted cross-entropy; the targets are the 3-channel float masks
    # produced by modify_data, so they are passed as per-class values.
    loss_func = nn.CrossEntropyLoss(weight=class_weights)
    loss_func = loss_func.to(device)

    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    # NOTE(review): the scheduler is stepped once per batch below, not once per
    # epoch, so T_max=10 counts batches — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.0001, T_max=10, last_epoch=-1)

    net = net.to(device)

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs("./trained_model", exist_ok=True)

    min_loss_train = 1e9
    min_loss_val = 1e9

    train_loss_list = []
    val_loss_list = []
    min_val_epoch = 0

    for epoch in tqdm(range(num_epochs)):

        # ---- training pass ----
        net.train()
        epoch_loss = 0
        running_loss = 0.0

        for inputs, labels in train_loader:

            optimizer.zero_grad()
            inputs, labels = inputs.to(torch.float32).to(device), labels.to(torch.float32).to(device)

            outputs = net(inputs)

            loss = loss_func(outputs, labels)

            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            epoch_loss += loss.item()

        # Mean of the per-batch losses.
        epoch_loss /= len(train_loader)

        train_loss_list.append(epoch_loss)
        if epoch % 50 == 0:
            # Periodic snapshot, tagged with the epoch it was taken at.
            torch.save(net.state_dict(), f"./trained_model/last_{now_model}_{fig_size}_{epoch}.pth")

        # ---- validation pass ----
        net.eval()
        with torch.no_grad():
            val_loss = 0
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(torch.float32).to(device), labels.to(torch.float32).to(device)

                outputs = net(inputs)

                loss = loss_func(outputs, labels)
                val_loss += loss.item()
            val_loss /= len(val_loader)
        # print(f"epoch: {epoch}, train_loss: {epoch_loss}, val_loss: {val_loss}")
        val_loss_list.append(val_loss)

        # Overwrite the "best" checkpoint whenever validation loss improves.
        if val_loss < min_loss_val:
            torch.save(net.state_dict(), f"./trained_model/best_{now_model}_{fig_size}_{num_epochs}.pth")
            min_loss_val = val_loss
            min_val_epoch = epoch

    # Weights after the final epoch, regardless of validation loss.
    torch.save(net.state_dict(), f"./trained_model/last_{now_model}_{fig_size}_{num_epochs}.pth")
    print(f"{now_model}, min_val_epoch:{min_val_epoch} min_loss_val:{min_loss_val}")

    x = np.arange(len(train_loss_list))

    plt.plot(x, train_loss_list, label='Train Loss')
    plt.plot(x, val_loss_list, label='Validation Loss')

    plt.legend()
    plt.grid()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.show()

In [None]:
# Release cached GPU memory before the evaluation model is loaded.
torch.cuda.empty_cache()
def plot_images(ax, images):
    """Draw ``images`` on ``ax`` in grayscale.

    If the leading dimension has length 1 it is dropped first, so a
    (1, H, W) array is shown as an (H, W) image.
    """
    frame = images[0] if images.shape[0] == 1 else images
    ax.imshow(frame, cmap="gray")
def probability_to_label(origin):
    """Convert per-class scores into grayscale label maps.

    Args:
        origin: tensor of shape (N, C, H, W) with one score per class in
            dimension 1 (class 0 = free, 1 = occupied, 2 = unknown).

    Returns:
        int64 CPU tensor of shape (N, H, W) where the winning class is encoded
        as 255 (free), 0 (occupied) or 100 (unknown). Any class index above 2
        is left as its raw index.
    """
    # Take the argmax for the whole batch at once and move it to the CPU
    # first: the result always lives on the CPU, and indexing a CPU tensor
    # with masks that are still on the GPU raises on recent PyTorch versions.
    index_mat = torch.argmax(origin, dim=1).cpu()

    # Masks come from index_mat, not res, so a remapped value (e.g. 0) can
    # never be mistaken for a class index by a later assignment.
    res = index_mat.clone()
    res[index_mat == 0] = 255
    res[index_mat == 1] = 0
    res[index_mat == 2] = 100

    return res

def evaluate_result(gt, predict):
    """Score how well ``predict`` recovers the pixels labelled 0 in ``gt``.

    Args:
        gt: array of shape (N, H, W) with ground-truth labels.
        predict: array of the same shape with predicted labels.

    Returns:
        Tuple ``(recall, precision, f1_score)`` for label 0, pooled over all
        samples. Each value is 0 when its denominator would be zero.
    """
    N, H, W = gt.shape  # unpacking requires gt to be three-dimensional

    gt_hit = gt == 0
    pred_hit = predict == 0

    # Pool the confusion counts over every pixel of every sample.
    total_tp = np.sum(gt_hit & pred_hit)
    total_fn = np.sum(gt_hit & ~pred_hit)
    total_fp = np.sum(~gt_hit & pred_hit)

    if (total_tp + total_fn) > 0:
        recall = total_tp / (total_tp + total_fn)
    else:
        recall = 0

    if (total_tp + total_fp) > 0:
        precision = total_tp / (total_tp + total_fp)
    else:
        precision = 0

    # F1 combines precision and recall into a single number.
    if (precision + recall) > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0

    return recall,precision,f1_score


# Number of class channels in every input / target map.
n_channels = 3

# Evaluate the best checkpoint of every model on the test split and show a
# few input / ground-truth / prediction triples.
for now_model in model_name_list:
    print(f"now_model: {now_model}")

    best_net = get_model(now_model)
    #Note that the training process is not so stable, you can also try the checkpoint of the last epoch
    # map_location lets a checkpoint saved on one device load on another.
    state_dict = torch.load(
        f"./trained_model/best_{now_model}_{fig_size}_{num_epochs}.pth",
        weights_only=True,
        map_location=device,
    )

    best_net.load_state_dict(state_dict)

    best_net = best_net.to(device)

    best_net.eval()
    input_list = []
    gt_img_list = []
    predict_img_list = []

    for test_image, gt_img in test_loader:

        test_image = test_image.reshape(-1, n_channels, fig_size, fig_size).to(torch.float32).to(device)

        with torch.no_grad():
            predicted_img = best_net(test_image)

        # Convert to label maps on the CPU so the conversion never mixes
        # CPU tensors with masks that live on the GPU.
        gt_labels = probability_to_label(gt_img.cpu())
        predicted_labels = probability_to_label(predicted_img.cpu())
        input_labels = probability_to_label(test_image.cpu())

        for now_gt_img, now_pre, now_input in zip(gt_labels, predicted_labels, input_labels):
            gt_img_list.append(now_gt_img.cpu().numpy().reshape(fig_size, fig_size))
            predict_img_list.append(now_pre.cpu().numpy().reshape(fig_size, fig_size))
            input_list.append(now_input.cpu().numpy().reshape(fig_size, fig_size))

    input_list = np.array(input_list)
    gt_img_list = np.array(gt_img_list)
    predict_img_list = np.array(predict_img_list)

    recall, precision, f1_score = evaluate_result(gt_img_list, predict_img_list)
    print(f"{now_model} test recall: {recall}, precision: {precision}, f1_score: {f1_score}")

    # print(f"{now_model} test loss: {current_loss}")

    # Only keep sample indices that exist, so a small test set does not raise
    # an IndexError.
    view_index = [idx for idx in (1, 10, 100) if idx < len(gt_img_list)]
    if not view_index:
        continue

    fig_list = [input_list[view_index], gt_img_list[view_index], predict_img_list[view_index]]
    title_list = ['Input Label', 'GT Label', 'Predicted']
    eval_num = len(view_index)
    # squeeze=False keeps `axes` two-dimensional even when only one row is left.
    fig, axes = plt.subplots(
        eval_num, len(fig_list),
        figsize=(len(fig_list) * 2.4, eval_num * 2.5),
        squeeze=False,
    )
    for i, ax_list in enumerate(axes):
        for j in range(len(fig_list)):

            current_ax = ax_list[j]
            current_img = fig_list[j][i]
            current_title = title_list[j]
            plot_images(current_ax, current_img)
            current_ax.set_title(current_title)
            current_ax.axis('off')

    plt.show()