<a href="https://colab.research.google.com/github/skywalker00001/AlgorithmProject/blob/master/face_parsing2_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Version 2.1

# Import

In [1]:
import os

from google.colab import drive

# Mount Google Drive and anchor every project path on the absolute mount point, so the
# notebook keeps working even when the working directory is not /content (e.g. after %cd).
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
ROOT = os.path.join(DRIVE_MOUNT_POINT, 'MyDrive/ACV/Project1')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Check which GPU we have access to. The output below is from the Google Colab version.
!nvidia-smi

Tue Mar 22 22:17:54 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
!pip install segmentation_models_pytorch
!pip install wandb -qqq
import wandb



In [4]:
# Log in to wandb so that the model run and all the parameters can be logged.
# SECURITY: never store the wandb API key in the notebook. A key that was written
# here in an earlier revision must be treated as leaked and revoked.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mskywalk3r[0m (use `wandb login --relogin` to force relogin)


In [5]:
# Log in through the Python API as well.
# NOTE(review): this looks redundant with the `!wandb login` shell call in the previous cell - confirm.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mskywalk3r[0m (use `wandb login --relogin` to force relogin)


True

In [None]:
#import wandb
import random
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
import cv2
import time

import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
from torch import cuda
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp

In [None]:
# Select the device: use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if cuda.is_available() else 'cpu'
print("DEVICE is: ", DEVICE)

# Seed every random number generator imported by the notebook and make cuDNN
# deterministic for reproducibility.
SEED = 42
random.seed(SEED)        # python random seed
np.random.seed(SEED)     # numpy random seed
torch.manual_seed(SEED)  # pytorch random seed
torch.backends.cudnn.deterministic = True
# cuDNN auto-tuning may select different kernels from run to run, which defeats determinism.
torch.backends.cudnn.benchmark = False

# Utils

In [None]:
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy loss for dense (per-pixel) classification.

    Args:
        input: logits of shape [n, c, h, w].
        target: integer class map of shape [n, ht, wt]; pixels whose value is 250 are ignored.
        weight: optional per-class weight tensor forwarded to the loss.
        size_average: when True the loss is averaged over the contributing pixels,
            otherwise it is summed.

    Returns:
        A scalar tensor holding the loss.
    """
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Handle inconsistent size between input and target
    if h != ht or w != wt:
        input = nn.functional.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # [n, c, h, w] -> [n*h*w, c]; reshape (unlike view) also accepts non-contiguous tensors
    logits = input.permute(0, 2, 3, 1).reshape(-1, c)
    flat_target = target.reshape(-1)

    # `size_average` is deprecated in torch; `reduction` is its documented replacement
    reduction = "mean" if size_average else "sum"
    return nn.functional.cross_entropy(
        logits, flat_target, weight=weight, ignore_index=250, reduction=reduction
    )

# predict, groundtruth are both [batch, imsize, imsize]
def get_miou(predict, groundtruth, num_classes=19, smoothing=1e-6):
    """Return the SUM over the batch of the per-image mean IoU, expressed in percent.

    For every image the IoU of each class in range(num_classes) is computed, `smoothing`
    is added to both the intersection and the union, and the per-class values are averaged.
    Because of the smoothing term, a class that is absent from both the prediction and the
    ground truth contributes an IoU of 100%.
    The caller is expected to divide the returned sum by the number of images.
    """
    pred_cpu = predict.to('cpu')
    gt_cpu = groundtruth.to('cpu')

    def image_miou(pred_map, gt_map):
        # per-class (smoothed) intersection and union areas of a single image
        intersections = torch.zeros(num_classes, device='cpu')
        unions = torch.zeros(num_classes, device='cpu')
        for cls_idx in range(num_classes):
            overlap = torch.sum((pred_map == gt_map) * (pred_map == cls_idx))
            pred_area = torch.sum(pred_map == cls_idx)
            gt_area = torch.sum(gt_map == cls_idx)
            intersections[cls_idx] += overlap + smoothing
            unions[cls_idx] += (pred_area + gt_area - overlap) + smoothing
        return (intersections / unions * 100.0).mean()

    # sum() starts from the integer 0, so an empty batch yields 0
    return sum(image_miou(pred_cpu[i], gt_cpu[i]) for i in range(predict.size()[0]))

In [None]:
def get_my_palette(impath):
    """Return the colour palette of the image stored at `impath`.

    Args:
        impath: path of an image file.

    Returns:
        Whatever `Image.getpalette()` returns for that image.
    """
    # The context manager releases the file handle once the palette has been read.
    with Image.open(impath) as img:
        return img.getpalette()

def put_my_palette(img, pale):
    """Attach the palette `pale` to `img` (in place) and return the same image.

    `putpalette` modifies the image in place and returns None, so its result must not
    be used as the return value.
    """
    img.putpalette(pale)
    return img

# ts: [512, 512] after * 255
def tensor2uint18(ts):
    """Convert a tensor to a numpy uint8 array (cast to int64 on the way, then to uint8)."""
    return ts.long().cpu().numpy().astype(np.uint8)

In [None]:
# images: [batch, 3, 512, 512], tensor
# labels: [batch, 512, 512], tensor
def wandb_log_image_table(images, predicted, labels, pale):
    "Log a wandb.Table with one (image, pred, target) row per sample"

    def as_palette_image(class_map):
        # class-index map -> uint8 PIL image coloured through the palette
        pil_img = Image.fromarray(tensor2uint18(class_map))
        pil_img.putpalette(pale)
        return pil_img

    table = wandb.Table(columns=["image", "pred", "target"])
    for img, pred, targ in zip(images.cpu(), predicted.cpu(), labels.cpu()):
        # channels-first integer tensor -> channels-last numpy array
        channels_last = np.transpose(img.long().numpy(), (1, 2, 0))
        table.add_data(
            wandb.Image(channels_last),
            wandb.Image(as_palette_image(pred)),
            wandb.Image(as_palette_image(targ)),
        )
    # commit=False: the table is attached to the step that a later wandb.log call commits
    wandb.log({"predictions_table":table}, commit=False)

In [None]:
def save_model(model, optimizer, epoch, miou, PATH):
    """Write a checkpoint (model state, optimizer state, epoch and miou) to PATH with torch.save."""
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        miou=miou,
    )
    torch.save(checkpoint, PATH)
# Helper function to split an elapsed time into minutes and seconds
def total_time(start_time, end_time):
    """Return (minutes, seconds) elapsed between two timestamps given in seconds."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

# def load_model(model, optimizer, PATH):
#     checkpoint = torch.load(PATH)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     epoch = checkpoint['epoch']
#     miou = checkpoint['miou']
#     return model, optimizer, epoch, miou

# Data Loader

In [None]:

class FaceParse_Dataset(Dataset):
    '''Dataset of face images and, for the train/val splits, their parsing masks.

    Files are addressed by consecutive index: `<i>.jpg` inside `img_path` and `<i>.png`
    inside `label_path` for i = 0 .. N-1, where N is the number of regular files found in
    `img_path`. The files are NOT checked for existence here.

    mode="train/val/test"
    With 512x512 transforms (as configured elsewhere in this notebook):
        img.shape: torch.Size([3, 512, 512])
        label.shape: torch.Size([1, 512, 512])
    In "test" mode `__getitem__` returns the transformed image only.
    '''

    VALID_MODES = ("train", "val", "test")

    def __init__(self, img_path, label_path, transform_img, transform_label, mode="train"):
        # An unknown mode used to produce a silently empty dataset; fail early instead.
        if mode not in self.VALID_MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self.VALID_MODES, mode))

        self.img_path = img_path
        self.label_path = label_path
        self.transform_img = transform_img
        self.transform_label = transform_label
        self.train_dataset = []
        self.val_dataset = []
        self.test_dataset = []
        self.mode = mode
        self.preprocess()

        if mode == "train":
            self.num_images = len(self.train_dataset)
        elif mode == "val":
            self.num_images = len(self.val_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Fill the list of the active mode with image paths (and mask paths for train/val)."""
        num_files = sum(
            1 for name in os.listdir(self.img_path)
            if os.path.isfile(os.path.join(self.img_path, name))
        )

        for i in range(num_files):
            img_path = os.path.join(self.img_path, str(i)+'.jpg')
            if self.mode == "test":
                # the test split has no masks
                self.test_dataset.append(img_path)
            else:
                label_path = os.path.join(self.label_path, str(i)+'.png')
                samples = self.train_dataset if self.mode == "train" else self.val_dataset
                samples.append([img_path, label_path])

        print(f'Finished preprocessing the CelebA dataset in {self.mode} mode...')

    def __getitem__(self, index):
        """Return the transformed image ("test") or the (image, label) pair ("train"/"val")."""
        if self.mode == "test":
            # convert('RGB') guarantees 3 channels even for grayscale / CMYK / RGBA files
            image = Image.open(self.test_dataset[index]).convert('RGB')
            return self.transform_img(image)

        dataset = self.train_dataset if self.mode == "train" else self.val_dataset
        img_path, label_path = dataset[index]
        image = Image.open(img_path).convert('RGB')
        # the mask is left in its stored mode so that its pixel values stay class indices
        label = Image.open(label_path)
        return self.transform_img(image), self.transform_label(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images

class Data_Loader():
    """Builds the image/label transforms and the torch DataLoader for one dataset split."""

    def __init__(self, img_path, label_path, image_size, batch_size, mode):
        self.img_path = img_path
        self.label_path = label_path
        self.imsize = image_size
        self.batch = batch_size
        self.mode = mode

    def transform_img(self, resize, totensor, normalize, centercrop):
        """Compose the image transform; every flag switches one step on or off."""
        options = []
        if centercrop:
            options.append(T.CenterCrop(160))
        if resize:
            options.append(T.Resize((self.imsize,self.imsize)))
        if totensor:
            options.append(T.ToTensor())
        if normalize:
            options.append(T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return T.Compose(options)

    def transform_label(self, resize, totensor, normalize, centercrop):
        """Compose the label (mask) transform; every flag switches one step on or off."""
        options = []
        if centercrop:
            options.append(T.CenterCrop(160))
        if resize:
            # Nearest-neighbour is requested explicitly: any interpolating filter would
            # blend neighbouring class ids into values that are not valid labels.
            options.append(T.Resize((self.imsize,self.imsize),
                                    interpolation=T.InterpolationMode.NEAREST))
        if totensor:
            options.append(T.ToTensor())
        if normalize:
            # NOTE(review): a std of 0 would divide by zero and class-index masks should
            # not be normalised at all; loader() never enables this branch.
            options.append(T.Normalize((0, 0, 0), (0, 0, 0)))
        return T.Compose(options)

    def loader(self):
        """Create the dataset for `self.mode` and wrap it in a DataLoader."""
        transform_img = self.transform_img(True, True, True, False)
        transform_label = self.transform_label(True, True, False, False)
        dataset = FaceParse_Dataset(self.img_path, self.label_path, transform_img, transform_label, self.mode)

        # only the training split is shuffled, so val/test keep their file order
        return DataLoader(dataset=dataset,
                          batch_size=self.batch,
                          shuffle=(self.mode=="train"),
                          num_workers=2,
                          drop_last=False)

# Trainer

In [None]:
class Trainer(object):
    """Runs the train/validation loop, logs the metrics to wandb and saves checkpoints.

    `config` is a dict; the keys read here are MODEL_VERSION, DEVICE, PALETTE, NUM_EPOCH,
    START_EPOCH, NUM_CLASSES, MODEL_SAVE_STEP, MODEL_SAVE_PATH, BEST_MIOUS, BEST_EPOCHS,
    LOG_IMAGES and SMOOTHING. The BEST_MIOUS / BEST_EPOCHS lists are updated in place.
    """

    # Statistics used to undo the image normalisation before logging pictures.
    # NOTE(review): must match the Normalize((0.5,...), (0.5,...)) of the image transform.
    IMG_MEAN = 0.5
    IMG_STD = 0.5

    def __init__(self, model, optimizer, train_loader, val_loader, config):

        self.model_version = config["MODEL_VERSION"]
        self.device = config["DEVICE"]
        self.pale = config["PALETTE"]

        # Data loader
        self.train_loader = train_loader
        self.val_loader = val_loader

        # exact model and optimizer
        self.model = model
        self.optimizer = optimizer

        self.num_epoch = config["NUM_EPOCH"]
        self.start_epoch = config["START_EPOCH"]
        self.num_classes = config["NUM_CLASSES"]

        # Save model
        self.model_save_step = config["MODEL_SAVE_STEP"]
        self.model_save_path = config["MODEL_SAVE_PATH"]

        self.best_mious = config["BEST_MIOUS"]
        self.best_epochs = config["BEST_EPOCHS"]

        self.need_log_images = config["LOG_IMAGES"]
        self.smoothing = config["SMOOTHING"]

    @staticmethod
    def _restore_labels(labels):
        """Turn ToTensor-scaled masks [batch, 1, h, w] into int64 class ids [batch, h, w].

        ToTensor divides by 255; multiplying back in floating point may land just below
        the integer, so the values are rounded before the cast (a plain cast truncates).
        """
        return torch.round(labels[:, 0, :, :] * 255.0).long()

    def _checkpoint_path(self, template, *args):
        """Format only the file name, so braces in the directory name cannot break it."""
        return os.path.join(self.model_save_path, template.format(*args))

    def train(self):
        """Train for `num_epoch` epochs starting at `start_epoch`.

        Returns:
            (best_mious, best_epochs): the lists taken from the config, updated in place.
            Epochs are stored 1-based, consistently with the 'epoch' value of the checkpoints.
        """
        # make sure a missing directory does not lose the first checkpoint
        os.makedirs(self.model_save_path, exist_ok=True)
        last_epoch = self.start_epoch + self.num_epoch

        for epoch in range(self.start_epoch, last_epoch):
            self.model.train()

            train_loss = 0.
            train_miou = 0.
            train_num = 0

            # train
            with tqdm(total=len(self.train_loader), desc="training progress bar") as progress_bar:
                progress_bar.set_description('Epoch: {}/{} training'.format(epoch+1, last_epoch))
                for batch, (imgs, labels) in enumerate(self.train_loader):
                    # imgs: [batch, 3, imsize, imsize]
                    # labels: [batch, 1, imsize, imsize]
                    imgs, labels = imgs.to(self.device), labels.to(self.device)

                    # Forward
                    # outputs: [batch, num_class, imsize, imsize]
                    outputs = self.model(imgs)

                    # labels_real_plain: [batch, imsize, imsize], int64 class ids
                    labels_real_plain = self._restore_labels(labels)

                    # compute loss (average over the batch), accumulated weighted by batch size
                    loss = cross_entropy2d(outputs, labels_real_plain)
                    train_loss += loss.item() * imgs.size(0)

                    # Backprop the gradient and update parameters
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # compute miou
                    # pred_mask: [batch, imsize, imsize]
                    pred_mask = torch.argmax(outputs, dim=1)
                    train_miou += float(get_miou(pred_mask, labels_real_plain,
                                                 num_classes=self.num_classes, smoothing=self.smoothing))
                    # count total_num
                    train_num += imgs.size(0)

                    if batch % 10 == 9:
                        progress_bar.set_postfix(batch='{}'.format(batch),
                                                 train_miou="{:.2f}%".format(train_miou / train_num),
                                                 train_loss='{:.5f}'.format(train_loss / train_num))
                    progress_bar.update(1)

            # update wandb
            wandb.log({"train_loss": (train_loss / train_num), "train_miou": (train_miou / train_num),
                       "epoch": epoch+1}, step=epoch - self.start_epoch)

            # Log validation metrics
            val_loss, val_miou = self.valid(epoch, log_images=self.need_log_images, batch_idx=0)
            wandb.log({"val_loss": val_loss, "val_miou": val_miou}, step=epoch - self.start_epoch)

            # Save model based on miou: the current model replaces the stored model with the lowest miou
            worst_slot = self.best_mious.index(min(self.best_mious))
            if val_miou > self.best_mious[worst_slot]:
                self.best_mious[worst_slot] = val_miou
                self.best_epochs[worst_slot] = epoch + 1
                save_model(self.model, self.optimizer, epoch+1, val_miou,
                           self._checkpoint_path('{}_MODEL{}.pth', worst_slot, self.model_version))
                print("Saving best model at epoch {} with miou {}".format(epoch+1, val_miou))

            # Save model based on epoch: once every `model_save_step` epochs
            if epoch % self.model_save_step == (self.model_save_step - 1):
                save_model(self.model, self.optimizer, epoch+1, val_miou,
                           self._checkpoint_path('FUNDAMODEL{}.pth', self.model_version))
                print("Saving fundamodel at epoch {} with miou {}".format(epoch+1, val_miou))

        return self.best_mious ,self.best_epochs

    def valid(self, epoch, log_images=False, batch_idx=0):
        "Compute performance of the model on the validation dataset and log a wandb.Table"
        self.model.eval()
        val_loss = 0.
        val_miou = 0.
        val_num = 0

        with torch.inference_mode():
            # use the loader owned by the trainer, not a notebook-level global
            with tqdm(total=len(self.val_loader), desc="validating progress bar") as progress_bar:
                progress_bar.set_description('epoch: {}/{} validating'.format(epoch+1, self.start_epoch+ self.num_epoch))

                for batch, (imgs, labels) in enumerate(self.val_loader):
                    # imgs: [batch, 3, imsize, imsize]
                    # labels: [batch, 1, imsize, imsize]
                    imgs, labels = imgs.to(self.device), labels.to(self.device)
                    # outputs: [batch, num_class, imsize, imsize]
                    outputs = self.model(imgs)
                    # labels_real_plain: [batch, imsize, imsize], int64 class ids
                    labels_real_plain = self._restore_labels(labels)
                    # compute loss (average loss over one batch)
                    loss = cross_entropy2d(outputs, labels_real_plain)
                    val_loss += loss.item() * imgs.size(0)
                    # compute miou
                    # pred_mask: [batch, imsize, imsize]
                    pred_mask = torch.argmax(outputs, dim=1)
                    val_miou += float(get_miou(pred_mask, labels_real_plain,
                                               num_classes=self.num_classes, smoothing=self.smoothing))
                    # count val_num
                    val_num += imgs.size(0)

                    # Log one batch of images to the dashboard, always same batch_idx,
                    # and only on the epochs on which the periodic checkpoint is written.
                    if (epoch % self.model_save_step == (self.model_save_step - 1)) and batch==batch_idx and log_images:
                        # undo the normalisation so the logged pictures are in [0, 255]
                        pixels = (imgs * self.IMG_STD + self.IMG_MEAN) * 255
                        wandb_log_image_table(pixels, pred_mask, labels_real_plain, pale=self.pale)
                    # update progress_bar
                    progress_bar.set_postfix(miou="{:.2f}%".format(val_miou / val_num),
                                              loss='{:.5f}'.format(val_loss / val_num))
                    progress_bar.update(1)

        return val_loss / val_num, val_miou / val_num



# Tester

In [None]:

class Tester(object):
    """Runs a model over the test loader and writes the predicted masks to disk.

    Two PNG files are written per image, both named `<index>.png`: the raw class-index
    mask under `<RESULTS_PATH>/gray` and a palette-coloured copy under `<RESULTS_PATH>/color`.
    The keys read from `config` are DEVICE, PALETTE, TEST_BATCH_SIZE and RESULTS_PATH.
    """

    def __init__(self, model, test_loader, config):
        self.device = config["DEVICE"]
        self.pale = config["PALETTE"]

        self.model = model
        # Data loader
        self.test_loader = test_loader

        self.test_batch_size = config["TEST_BATCH_SIZE"]
        self.gray_label_path = os.path.join(config["RESULTS_PATH"], 'gray')
        self.color_label_path = os.path.join(config["RESULTS_PATH"], 'color')

        self.making_files()

    def making_files(self):
        """Create the output directories (including missing parents) when they do not exist."""
        for path, message in ((self.gray_label_path, "New gray file!"),
                              (self.color_label_path, "New color file!")):
            if not os.path.exists(path):
                # makedirs also creates RESULTS_PATH itself when it is missing
                os.makedirs(path, exist_ok=True)
                print(message)

    def test(self):
        "Run the model on the test dataset, save the predicted masks and return their indices."
        self.model.eval()

        test_num = 0
        result_num = []

        with torch.inference_mode():
            for batch, imgs in enumerate(self.test_loader):
                # imgs: [batch, 3, imsize, imsize]
                imgs = imgs.to(self.device)
                # outputs: [batch, num_class, imsize, imsize]
                outputs = self.model(imgs)
                # pred_mask_tensor: [batch, imsize, imsize]
                pred_mask_tensor = torch.argmax(outputs, dim=1)
                # argmax yields int64, which the image writers do not accept; class ids fit in uint8
                pred_mask_numpy = pred_mask_tensor.cpu().numpy().astype(np.uint8)

                for k in range(imgs.size(0)):
                    file_name = str(test_num + k) +'.png'

                    # gray picture: the pixel value is the class index
                    cv2.imwrite(os.path.join(self.gray_label_path, file_name), pred_mask_numpy[k])

                    # color picture: same mask rendered through the palette
                    color_label = Image.fromarray(pred_mask_numpy[k])
                    color_label.putpalette(self.pale)
                    color_label.save(os.path.join(self.color_label_path, file_name))

                    result_num.append(test_num + k)

                # count test_num
                test_num += imgs.size(0)

        print("Testing {} results has completed!".format(test_num))
        return result_num




# Parameters

In [None]:
# Configuration shared by the data loaders, the Trainer and the Tester.
Parameters = dict(
    MODEL_VERSION='2.1',
    MODEL_LOAD_VERSION='2.1',
    DEVICE=DEVICE,

    # Needchange
    NUM_EPOCH=40,
    START_EPOCH=0,
    NUM_CLASSES=19,
    IMSIZE=512,
    TRAIN_BATCH_SIZE=16,    # input batch size for training
    VAL_BATCH_SIZE=64,      # input batch size for validation
    TEST_BATCH_SIZE=64,     # input batch size for testing

    # Model Para
    LEARNING_RATE=1e-4,     # learning rate of the optimizer
    LR_DECAY=0.95,
    BETA1=0.5,
    BETA2=0.999,

    # Path
    TRAIN_PATH=os.path.join(ROOT, 'train'),
    VAL_PATH=os.path.join(ROOT, 'val'),
    TEST_PATH=os.path.join(ROOT, 'test'),
    RESULTS_PATH=os.path.join(ROOT, 'results'),

    # Save
    MODEL_SAVE_STEP=5,
    MODEL_SAVE_PATH=os.path.join(ROOT, 'models'),

    # Load
    MODEL_IF_LOAD=False,
    MODEL_LOAD_PATH=os.path.join(ROOT, 'models'),

    # Best model: how many of the best checkpoints are tracked
    EXPECTED_MODEL_NUMBER=4,

    LOG_IMAGES=False,
    SMOOTHING=1e-6,
)

# One slot per tracked checkpoint; each key gets its own list.
for tracker_key in ("BEST_MIOUS", "BEST_EPOCHS"):
    Parameters[tracker_key] = Parameters["EXPECTED_MODEL_NUMBER"] * [0]
# get palette
palette_source = os.path.join(Parameters["TRAIN_PATH"], 'train_mask/1.png')
Parameters["PALETTE"] = get_my_palette(palette_source)
# # Define model and optimizer
# Parameters["MODEL"] = smp.Unet(
#         encoder_name="resnet34",
#         encoder_weights="imagenet",
#         in_channels=3,
#         classes=Parameters["NUM_CLASSES"],
#     ).to(Parameters["DEVICE"])

# Parameters["OPTIMIZER"] = torch.optim.Adam(filter(lambda p: p.requires_grad, Parameters["MODEL"].parameters()), \
#                               Parameters["LEARNING_RATE"], [Parameters["BETA1"], Parameters["BETA2"]])

In [None]:
# Image / mask directories of every split (the test split has no masks).
train_img_path, train_label_path = (
    os.path.join(Parameters["TRAIN_PATH"], sub_dir) for sub_dir in ('train_image', 'train_mask')
)
val_img_path, val_label_path = (
    os.path.join(Parameters["VAL_PATH"], sub_dir) for sub_dir in ('val_image', 'val_mask')
)
test_img_path = os.path.join(Parameters["TEST_PATH"], 'test_image')

In [None]:
# train_img_path = os.path.join(Parameters["TRAIN_PATH"], 'train_image')
# train_label_path = os.path.join(Parameters["TRAIN_PATH"], 'train_mask')
# val_img_path = os.path.join(Parameters["VAL_PATH"], 'val_image')
# val_label_path = os.path.join(Parameters["VAL_PATH"], 'val_mask')
# test_img_path = os.path.join(Parameters["TEST_PATH"], 'test_image')

# One loader per split. train/val batches are (img, label) pairs, test batches are images only.
image_size = Parameters["IMSIZE"]
train_loader = Data_Loader(
    train_img_path, train_label_path, image_size, Parameters["TRAIN_BATCH_SIZE"], "train"
).loader()
val_loader = Data_Loader(
    val_img_path, val_label_path, image_size, Parameters["VAL_BATCH_SIZE"], "val"
).loader()
test_loader = Data_Loader(
    test_img_path, None, image_size, Parameters["TEST_BATCH_SIZE"], "test"
).loader()

# Main

In [None]:
# print()
# print("loaded_best_mious: " + str(Parameters["BEST_MIOUS"]))
# print("loaded_best_epochs: " + str(Parameters["BEST_EPOCHS"]))


In [None]:

# Full training run inside a wandb run: build the model and optimizer, optionally resume
# from the saved checkpoints, train, then report the best results and the elapsed time.
with wandb.init(
    # config = Parameters,
    project="Face_Parsing"+ Parameters["MODEL_VERSION"],
    ):

    # Define model and optimizer
    model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=Parameters["NUM_CLASSES"],
        ).to(Parameters["DEVICE"])

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 Parameters["LEARNING_RATE"], [Parameters["BETA1"], Parameters["BETA2"]])
    start_time = time.time()

    # if need, load the parameters of the model
    if (Parameters["MODEL_IF_LOAD"]):
        # update stored best_mious and best_epochs from the best-model checkpoints
        b_mious, b_epochs = [], []
        for i in range(Parameters["EXPECTED_MODEL_NUMBER"]):
            PATH = os.path.join(Parameters["MODEL_LOAD_PATH"],
                                "{}_MODEL{}.pth".format(i, Parameters["MODEL_LOAD_VERSION"]))
            # map_location lets a checkpoint saved on the GPU be read on a CPU-only machine
            checkpoint = torch.load(PATH, map_location=Parameters["DEVICE"])
            b_epochs.append(checkpoint['epoch'])
            b_mious.append(checkpoint['miou'])
        Parameters["BEST_MIOUS"] = b_mious
        Parameters["BEST_EPOCHS"] = b_epochs

        print()
        print("loaded_best_mious: " + str(Parameters["BEST_MIOUS"]))
        print("loaded_best_epochs: " + str(Parameters["BEST_EPOCHS"]))

        # really load the model and optimizer from the periodic checkpoint
        LOAD_PATH = os.path.join(Parameters["MODEL_LOAD_PATH"],
                                 "FUNDAMODEL{}.pth".format(Parameters["MODEL_LOAD_VERSION"]))
        load_checkpoint = torch.load(LOAD_PATH, map_location=Parameters["DEVICE"])
        model.load_state_dict(load_checkpoint['model_state_dict'])
        optimizer.load_state_dict(load_checkpoint['optimizer_state_dict'])
        # training resumes at the epoch stored in the checkpoint
        Parameters["START_EPOCH"] = load_checkpoint['epoch']
        print()
        print('Now the model is at epoch {} with miou {}.'.format(load_checkpoint['epoch'], load_checkpoint['miou']))

    trainer = Trainer(model, optimizer, train_loader, val_loader, Parameters)
    result_best_mious, result_best_epochs = trainer.train()
    print()
    print("best_mious: " + str(result_best_mious))
    print("best_epochs: " + str(result_best_epochs))

    end_time = time.time()
    epoch_mins, epoch_secs = total_time(start_time, end_time)
    print()
    print(f'Total Time: {epoch_mins}m {epoch_secs}s')

In [None]:
# Show the bookkeeping lists kept in the configuration: epochs first, then mIoUs.
for tracker_key in ("BEST_EPOCHS", 'BEST_MIOUS'):
    print(Parameters[tracker_key])

# Testing

In [None]:
# Define the model used for testing. No pretrained encoder weights are requested:
# every weight is overwritten by the checkpoint's state dict loaded before testing,
# so downloading the ImageNet weights here would be wasted work.
loaded_model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=Parameters["NUM_CLASSES"],
    ).to(Parameters["DEVICE"])

# List the epoch and miou stored in each of the best-model checkpoints.
epochs, mious = [], []
num = Parameters["EXPECTED_MODEL_NUMBER"]
for i in range(num):
    PATH = os.path.join(Parameters["MODEL_LOAD_PATH"],
                        "{}_MODEL{}.pth".format(i, Parameters["MODEL_LOAD_VERSION"]))
    # map_location lets a checkpoint saved on the GPU be read on a CPU-only machine
    checkpoint = torch.load(PATH, map_location=Parameters["DEVICE"])
    epochs.append(checkpoint['epoch'])
    mious.append(checkpoint['miou'])
print('best_epochs: ' + str(epochs))
print('best_mious: ' + str(mious))





In [None]:
# Index of the best-model checkpoint ("<index>_MODEL<version>.pth") evaluated on the test set.
# NOTE(review): the index is hard-coded; pick it from the mious printed by the previous cell.
TEST_CHECKPOINT_INDEX = 1

LOAD_PATH = os.path.join(Parameters["MODEL_LOAD_PATH"],
                         "{}_MODEL{}.pth".format(TEST_CHECKPOINT_INDEX, Parameters["MODEL_LOAD_VERSION"]))
# map_location lets a checkpoint saved on the GPU be read on a CPU-only machine
load_checkpoint = torch.load(LOAD_PATH, map_location=Parameters["DEVICE"])
loaded_model.load_state_dict(load_checkpoint['model_state_dict'])
loaded_epoch = load_checkpoint['epoch']
loaded_miou = load_checkpoint['miou']
print('loaded_epoch: ' + str(loaded_epoch))
print('loaded_miou: ' + str(loaded_miou))


# Write the predicted masks to disk and keep the list of written indices.
tester = Tester(loaded_model, test_loader, Parameters)
result_list = tester.test()


In [None]:
# Record how many predictions were written and their indices.
# The context manager closes the file even when one of the writes fails.
result_path = os.path.join(Parameters["RESULTS_PATH"], 'result_number.txt')
with open(result_path, 'w') as result_file:
    result_file.write('total_number: ' + str(len(result_list)))
    result_file.write('\n')
    result_file.write(str(result_list))