In [None]:
# !git clone https://github.com/tldrafael/DeepLabV3Plus-Pytorch
# !mv DeepLabV3Plus-Pytorch deeplabv3

# !pip install git+https://github.com/huggingface/transformers.git
# !git clone https://github.com/HRNet/HRNet-Semantic-Segmentation
# !wget https://github.com/hsfzxjy/models.storage/releases/download/HRNet-OCR/hrnet_ocr_cs_trainval_8227_torch11.pth
# !wget https://github.com/hsfzxjy/models.storage/releases/download/HRNet-OCR/hrnet_ocr_ade20k_4451_torch04.pth

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from seaborn import color_palette
from PIL import Image
import pandas as pd
from glob import iglob
import re
from datetime import datetime
import random
import cv2
from copy import copy
from transformers import Mask2FormerForUniversalSegmentation
from importlib import reload
from glob import iglob
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from sklearn.manifold import TSNE
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

import sys
sys.path.append('../src')
import utils as ut
import dataset as ds

In [None]:
sys.path.append("../../HRNet-Semantic-Segmentation/lib/")
from config import update_config, config
import models
from core.criterion import CrossEntropy, OhemCrossEntropy

In [None]:
from types import SimpleNamespace
# Namespace handed to `update_config(config, args)`: `cfg` is the experiment
# YAML, `opts` is a flat [KEY, value, KEY, value, ...] list of overrides.
# NOTE(review): the YAML is an ADE20K experiment while TEST.MODEL_FILE points
# at a Cityscapes checkpoint -- confirm this pairing is intended.
args = SimpleNamespace(
    cfg='../../HRNet-Semantic-Segmentation/experiments/ade20k/seg_hrnet_ocr_w48_520x520_ohem_sgd_lr2e-2_wd1e-4_bs_16_epoch120.yaml',
    opts=[
        'TEST.MODEL_FILE', '../../HRNet-Semantic-Segmentation/pretrained_models/hrnet_ocr_cs_trainval_8227_torch11.pth',
        'DATASET.TEST_SET', 'list/ade20k/testval.lst',
        # One output channel per mocamba class (ADE20K itself would be 150).
        'DATASET.NUM_CLASSES', len(ds.mocamba_classnames),
    ],
)



In [None]:
# Apply the YAML file and the key/value overrides carried by `args` to the
# HRNet `config` object, then echo the three overridden entries as a sanity check.
update_config(config, args)
print(config.TEST.MODEL_FILE, config.DATASET.TEST_SET, config.DATASET.NUM_CLASSES)

# Blank out MODEL.PRETRAINED via item assignment.
# NOTE(review): presumably this stops the model builder from looking for
# backbone weights on disk, and item assignment is used because the config is
# frozen after `update_config` -- confirm against the HRNet config code.
config['MODEL']['PRETRAINED'] = ''

In [None]:

@torch.no_grad()
def get_hrnet_segmap(inp, net=None):
    """Predict a per-pixel class-index map for a batch of images.

    Args:
        inp: image batch; its last two dimensions are taken as the spatial
            size the logits are resized to.
        net: network to run. Defaults to the notebook-level ``model`` so that
            existing one-argument calls keep working.

    Returns:
        ``long`` tensor of class indices with a singleton class dimension
        (``argmax(1, keepdim=True)``), on the device the network ran on.

    Side effects:
        Puts ``net`` in eval mode; callers that are training must switch back
        to ``train()`` themselves.
    """
    net = model if net is None else net
    net.eval()

    # Run on whatever device the network lives on instead of assuming CUDA.
    first_param = next(net.parameters(), None)
    device = inp.device if first_param is None else first_param.device

    # The network returns several outputs; config.TEST.OUTPUT_INDEX selects
    # the one used for prediction.
    logits = net(inp.to(device))[config.TEST.OUTPUT_INDEX]
    logits = F.interpolate(
        logits, size=inp.shape[-2:], mode='bilinear',
        align_corners=config.MODEL.ALIGN_CORNERS,
    )
    return logits.argmax(1, keepdim=True).long()


def get_CM_fromloader(
    dloader, model, n_classes, ix_nolabel=255, filling_signs=False
):
    """Accumulate a confusion matrix over a loader and derive per-class IoU.

    Args:
        dloader: iterable yielding ``(image_batch, label_batch)`` pairs.
        model: network to evaluate (it is the one actually used for
            prediction, not the notebook-level global).
        n_classes: number of classes; the confusion matrix is
            ``(n_classes, n_classes)``.
        ix_nolabel: currently unused; kept for interface compatibility.
        filling_signs: currently unused; kept for interface compatibility.

    Returns:
        ``(miou, CM_iou, CM_abs)``: the NaN-ignoring mean IoU, the per-class
        IoU vector (NaN for classes that never occur in either prediction or
        ground truth) and the absolute confusion-matrix counts.
    """
    CM_abs = np.zeros((n_classes, n_classes), dtype=int)

    for inp_im, inp_label in dloader:
        test_preds = get_hrnet_segmap(inp_im, model)
        # Per-sample accumulation; the row/column convention is whatever
        # `ut.get_CM` produces (IoU below is symmetric in it).
        for pr_i, y_i in zip(test_preds.cpu(), inp_label.cpu()):
            CM_abs += ut.get_CM(pr_i, y_i, n_classes)

    pred_P = CM_abs.sum(axis=0)
    gt_P = CM_abs.sum(axis=1)
    true_P = np.diag(CM_abs)

    # Absent classes give 0/0 -> NaN on purpose (nanmean skips them); only the
    # accompanying RuntimeWarning is silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        CM_iou = true_P / (pred_P + gt_P - true_P)
    miou = np.nanmean(CM_iou)
    return miou, CM_iou, CM_abs


def adjust_learning_rate(opt, base_lr, max_iters, cur_iters, power=0.9, nbb_mult=10):
    """Apply polynomial learning-rate decay to ``opt`` in place.

    The rate is ``base_lr * (1 - cur_iters / max_iters) ** power``. Once
    ``cur_iters`` exceeds ``max_iters`` the rate is held at 0 rather than
    letting the negative base turn the result into a complex number.

    Args:
        opt: optimizer-like object exposing ``param_groups`` (list of dicts).
        base_lr: learning rate at iteration 0.
        max_iters: iteration at which the rate reaches 0.
        cur_iters: current iteration.
        power: exponent of the polynomial decay.
        nbb_mult: multiplier applied to the second parameter group when the
            optimizer has exactly two groups.

    Returns:
        The learning rate written to the first parameter group.
    """
    # Clamp so that over-running max_iters yields 0.0 instead of a complex value.
    remaining = max(0.0, 1 - cur_iters / max_iters)
    lr = base_lr * remaining ** power
    opt.param_groups[0]['lr'] = lr
    if len(opt.param_groups) == 2:
        opt.param_groups[1]['lr'] = lr * nbb_mult
    return lr

In [None]:
# Point both batch-norm hooks of the HRNet-OCR module at the stock
# single-process nn.BatchNorm2d, then build the network from `config`.
hrnet_module = models.seg_hrnet_ocr
for bn_attr in ('BatchNorm2d_class', 'BatchNorm2d'):
    setattr(hrnet_module, bn_attr, nn.BatchNorm2d)
hrnet = hrnet_module.get_seg_model(config)

In [None]:
# Show the module tree of the network for inspection.
hrnet

In [None]:
# Presumably a deliberate barrier so that "Run All" stops here. Raising a bare
# string is itself a TypeError in Python 3, so raise a real exception with a
# message that explains the halt.
raise RuntimeError('Intentional stop: run the cells below manually.')

In [None]:
# Warm-start `hrnet` from the ADE20K checkpoint. Checkpoint keys carry a
# 6-character prefix that is sliced off; only tensors whose stripped name
# exists in the model are kept, and the cls_head / aux_head tensors are
# skipped so those layers keep their fresh initialisation.
state_dict = torch.load('../../HRNet-Semantic-Segmentation/pretrained_models/hrnet_ocr_ade20k_4451_torch04.pth')
model_dict = hrnet.state_dict()
# Raw string: '\.' in a normal literal is an invalid escape sequence.
head_pattern = re.compile(r'(cls_head\.|aux_head\.)')
state_dict = {
    k[6:]: v for k, v in state_dict.items()
    if k[6:] in model_dict and not head_pattern.search(k)
}
# The merge below never complains about missing keys, so report the coverage.
print(f'pretrained tensors loaded: {len(state_dict)}/{len(model_dict)}')
model_dict.update(state_dict)
hrnet.load_state_dict(model_dict, strict=False);

In [None]:
# Class table of the mocamba label set.
id2label = ds.IDs('mocamba').id2label
n_classes = len(id2label)
print(id2label)


# Long image side of the dataset variant to use; the crop window is defined
# for long=512 and scaled proportionally.
long = 256
# long = 512
# long = 1024
# long = 1536
# long = 2048

f = long/512
crop_size = tuple(int(side*f) for side in (256, 352))
print(f'long: {long}, crop_size: {crop_size}')

# NOTE(review): the same random pipeline is passed below as both `transform`
# and `transform_target`; unless SimpleDataset shares the random state between
# the two calls, image and label may be cropped/flipped differently -- confirm.
T_crop = T.Compose([
    T.RandomCrop(size=crop_size),
    T.RandomHorizontalFlip(p=.5)
])


train_annotation = f'../data/tidyv01-long{long}/trainpaths.txt'
val_annotation = f'../data/tidyv01-long{long}/valpaths.txt'


train_ds = ut.SimpleDataset(
    annotation_file=train_annotation,
    transform=T_crop,
    transform_target=T_crop,
)
val_ds = ut.SimpleDataset(annotation_file=val_annotation)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

In [None]:
# Pull one training batch and keep up to six samples as a fixed preview set
# (`dummy_im` / `dummy_label` are reused by later cells).
inp_im, inp_label = next(iter(train_loader))

dummy_im = inp_im.clone()[:6]
dummy_label = inp_label.clone()[:6]
# Label ids present in the preview batch (previously computed but discarded).
print(dummy_label.unique())
print(dummy_im.shape)

# One colour more than there are class names.
# NOTE(review): presumably the extra slot is for an ignore / no-label id -- confirm.
colorizer = ut.TorchColorizer(len(ds.mocamba_classnames)+1)

ims = ut.Normalize.reverse(dummy_im)
labels = colorizer(dummy_label)
alpha = .2
blend = (1-alpha)*ims + alpha*labels

# Stack image / label / blend along the second-to-last axis, then lay the
# samples side by side and move channels last for PIL.
tmp = torch.concat([ims, labels, blend], axis=-2)
tmp = tmp.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0)
tmp = ut.float_to_uint8(tmp.numpy())
Image.fromarray(tmp)

# Pre-trained HRNet

In [None]:
# trainpaths = list(iglob('../HRNet-Semantic-Segmentation/data/train/image/*'))
# print(len(trainpaths))
# p = np.random.choice(trainpaths)
# print(p)

# im = cv2.imread(p)[...,::-1]
# # im = im1
# inp = torch.tensor(im.copy()).permute(2,0,1)[None]/255
# inp = T.functional.resize(inp, (384*1,512*1))
# inp = ut.Normalize.forward(inp)
# inp.shape
# dummy_im = inp

In [None]:
# dummy_im = torch.load('/home/rafael/a.pt')
# hrnet = torch.load('/home/rafael/model.pt')

In [None]:
# Qualitative check of the warm-started network on the preview batch.
hrnet.eval().cuda();
model = hrnet

# Inference only: without no_grad the forward pass would record an autograd
# graph and hold on to its activations for nothing.
with torch.no_grad():
    pred = model(dummy_im.cuda())[config.TEST.OUTPUT_INDEX]
    print(pred.shape)
    pred = F.interpolate(pred, size=dummy_im.shape[-2:], mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
    print(pred.shape)
    pred = pred.argmax(1, keepdim=True).long()
segmentation_map = pred.cpu()

colorizer = ut.TorchColorizer(len(ds.mocamba_classnames))
# colorizer = ut.TorchColorizer(150)
# Shuffle the palette (unseeded, so colours differ from run to run).
colorizer.colors = colorizer.colors[np.random.permutation(len(colorizer.colors))]

ims = ut.Normalize.reverse(dummy_im)
labels = colorizer(segmentation_map)
alpha = .2
blend = (1-alpha)*ims + alpha*labels

# Stack image / prediction / blend along the second-to-last axis, then lay the
# samples side by side and move channels last for PIL.
tmp = torch.concat([ims, labels, blend], axis=-2)
tmp = tmp.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0)
tmp = ut.float_to_uint8(tmp.numpy())
Image.fromarray(tmp)

In [None]:
def _imread_or_fail(path):
    """Read ``path`` unchanged with OpenCV, raising if it cannot be read.

    ``cv2.imread`` signals failure by returning ``None``, which would
    otherwise surface later as an unrelated error on ``.shape`` or slicing.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    return img


# One ADE20K validation sample: saved prediction, image (channel order
# reversed from OpenCV's BGR) and ground-truth label.
pred1 = _imread_or_fail('../../HRNet-Semantic-Segmentation/test_results/ADE_val_00000001.png')
im1 = _imread_or_fail('../../HRNet-Semantic-Segmentation/data/val/image/ADE_val_00000001.jpg')[..., ::-1]
gt1 = _imread_or_fail('../../HRNet-Semantic-Segmentation/data/val/label/ADE_val_00000001.png')
print(pred1.shape, im1.shape, gt1.shape)

In [None]:
# Image, ground truth and prediction in a 1x3 grid, left to right.
for position, panel in enumerate((im1, gt1, pred1), start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel)

# Training

In [None]:

print(config.LOSS.USE_OHEM)


# OHEM cross-entropy is used regardless of the flag printed above; its
# thresholds and ignore label come from the HRNet config.
criterion = OhemCrossEntropy(
    ignore_label=config.TRAIN.IGNORE_LABEL,
    thres=config.LOSS.OHEMTHRES,
    min_kept=config.LOSS.OHEMKEEP,
    weight=None
)


# Rebuild the network from scratch so training starts from a clean model.
hrnet = models.seg_hrnet_ocr
hrnet.BatchNorm2d_class = hrnet.BatchNorm2d = nn.BatchNorm2d
hrnet = hrnet.get_seg_model(config)

# Warm start from the ADE20K checkpoint: strip the 6-character key prefix,
# keep only tensors the model knows, and skip cls_head / aux_head so those
# layers keep their fresh initialisation.
state_dict = torch.load('../../HRNet-Semantic-Segmentation/pretrained_models/hrnet_ocr_ade20k_4451_torch04.pth')
model_dict = hrnet.state_dict()
# Raw string: '\.' in a normal literal is an invalid escape sequence.
head_pattern = re.compile(r'(cls_head\.|aux_head\.)')
state_dict = {
    k[6:]: v for k, v in state_dict.items()
    if k[6:] in model_dict and not head_pattern.search(k)
}
model_dict.update(state_dict)
hrnet.load_state_dict(model_dict, strict=False);

model = hrnet

In [None]:
# print(config.TRAIN.OPTIMIZER, config.TRAIN.NONBACKBONE_KEYWORDS, config.TRAIN.LR,
#       config.TRAIN.MOMENTUM, config.TRAIN.WD, config.TRAIN.NESTEROV)

# opt = torch.optim.SGD(
#     model.parameters(), lr=config.TRAIN.LR, momentum=config.TRAIN.MOMENTUM,
#     weight_decay=config.TRAIN.WD, nesterov=config.TRAIN.NESTEROV)


# AdamW over every model parameter (single parameter group, fixed base LR).
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

In [None]:

# LR warm-up, stepped every `warmup_step` iterations by the training loop.
scheduler = ut.WarmupLR(opt, n_warmup_max=1000)
warmup_step = 1

# Run name = timestamp + experiment tag; doubles as the log directory name.
time_now = datetime.now().strftime("%Y%m%d_%H%M") + f'-PAPER-HRNet-tidyv01+long{long}'
print(time_now)

logdir = os.path.join('../logs', time_now)
writer = SummaryWriter(log_dir=logdir)
writer.flush()


# Gradient scaler used by the training loop's backward/step calls.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Datasets for the actual run: both splits are built without transforms here.
bs_train = 8
train_ds = ut.SimpleDataset(annotation_file=f'../data/tidyv01-long{long}/trainpaths.txt')
val_ds = ut.SimpleDataset(annotation_file=f'../data/tidyv01-long{long}/valpaths.txt')

train_loader = DataLoader(train_ds, batch_size=bs_train, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

model.cuda();

In [None]:
# Set to True to continue with the counters left over from a previous run of
# this cell instead of resetting them.
fl_resume = False


if not fl_resume:
    # Fresh run: reset iteration counter, logging step and loss buffers.
    it, steps_cur, miou_best = 1, 0, 0
    losses, lossi = [], []

    lr_cur = opt.param_groups[0]["lr"]
    print(f'lr: {lr_cur:.2e}')
    writer.add_scalars('lr', {'nn': lr_cur}, steps_cur)

    # Baseline mIoU on both splits before any optimisation step.
    model.eval()
    miou_train = get_CM_fromloader(train_loader, model, n_classes)[0]
    miou_val = get_CM_fromloader(val_loader, model, n_classes)[0]
    print(f'\tmiou-train: {miou_train:.2f}\tmiou-val: {miou_val:.2f}')
    for split, split_miou in (('train', miou_train), ('val', miou_val)):
        writer.add_scalars('miou', {split: split_miou}, steps_cur)


# Micro-batches accumulated per optimizer step so the effective batch is 8;
# the evaluation period and total budget are scaled by the same factor.
n_gradacc = max(8//bs_train,1) # bs=8
print_every_n = 250 * n_gradacc
iters_total = 10000 * n_gradacc
print(f'grad-acc: {n_gradacc}')
print(f'iters_total: {iters_total}')


# Main optimisation loop. `it` counts micro-batches; the outer `while`
# restarts the loader (a new epoch) until `iters_total` is reached.
model.train()
opt.zero_grad(set_to_none=True)

# LR drops at 50% / 90% of the budget, scaled by n_gradacc exactly like
# iters_total and print_every_n so they stay aligned under accumulation.
lr_decay_iters = {5000 * n_gradacc, 9000 * n_gradacc}

while it <= iters_total:

    for inp_im, inp_label in train_loader:
        inp_im = inp_im.cuda()
        # Drop dimension 1 of the label batch before handing it to the loss.
        inp_label = inp_label[:, 0].cuda()

        # Mixed precision is switched off here (enabled=False): the forward
        # pass runs in full precision; the scaler is kept so AMP can be
        # re-enabled by flipping the flag.
        with torch.cuda.amp.autocast(dtype=torch.float16, enabled=False):
            pred = model(inp_im)
            loss = criterion(pred, inp_label).mean()

        # Divide by the number of accumulated micro-batches so the summed
        # gradient is an average, independent of n_gradacc.
        scaler.scale(loss / n_gradacc).backward()

        # Step only once every n_gradacc micro-batches.
        if it % n_gradacc == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        # Log the un-normalised loss so curves are comparable across settings.
        lossi.append(loss.item())
        it += 1

        if scheduler.fl_warmup and it % warmup_step == 0:
            scheduler.step()
            # print(f'lr: {opt.param_groups[0]["lr"]:.2e}')

        if it % print_every_n == 0:
            steps_cur += 1

            loss_avg = np.mean(lossi)
            losses.append(loss_avg)

            # Full-split evaluation; switch back to train mode afterwards.
            model.eval()
            miou_train = get_CM_fromloader(train_loader, model, n_classes)[0]
            miou_val = get_CM_fromloader(val_loader, model, n_classes)[0]
            model.train()

            print(
                f'it: {it}\tloss: {loss_avg:.4f}'
                f'\tmiou-train: {miou_train:.3f}'
                f'\tmiou-val: {miou_val:.3f}'
            )

            writer.add_scalars('objective', {'train': loss_avg}, steps_cur)
            writer.add_scalars('miou', {'train': miou_train}, steps_cur)
            writer.add_scalars('miou', {'val': miou_val}, steps_cur)

            # Start a new averaging window for the next report.
            lossi = []

            ckpt_last_path = os.path.join(logdir, 'model.last.pth')
            ut.save_ckpt(ckpt_last_path, model, opt, miou_val, it)

            if miou_val > miou_best:
                miou_best = miou_val
                ckpt_best_path = os.path.join(logdir, 'model.best.pth')
                ut.save_ckpt(ckpt_best_path, model, opt, miou_val, it)

            # Step decay: divide every group's LR by 10 at the milestones
            # (checked only at evaluation points, which the milestones are
            # multiples of).
            if it in lr_decay_iters:
                for pg in opt.param_groups:
                    pg['lr'] *= .1

        if it > iters_total:
            break

In [None]:
# for pg in opt.param_groups:
#     print(pg['lr'])
#     pg['lr'] *= .1
#     print(pg['lr'])
    

In [None]:
# Same preview batch as before training, now coloured with the trained
# model's predictions (swap in `dummy_label` to view the ground truth).
model.eval()
pred = get_hrnet_segmap(dummy_im).cpu()
ims, labels = ut.Normalize.reverse(dummy_im), colorizer(pred)
alpha = .2
blend = (1-alpha)*ims + alpha*labels
# Image / prediction / blend stacked along the second-to-last axis, samples
# laid side by side, channels moved last for PIL.
stacked = torch.concat([ims, labels, blend], axis=-2)
tmp = ut.float_to_uint8(
    stacked.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0).numpy()
)
Image.fromarray(tmp)

# Evaluating model

In [None]:
n_classes = len(ds.mocamba_classnames)

# Fresh HRNet-OCR instance with plain nn.BatchNorm2d as its norm layer.
hrnet_module = models.seg_hrnet_ocr
hrnet_module.BatchNorm2d = nn.BatchNorm2d
hrnet_module.BatchNorm2d_class = nn.BatchNorm2d
hrnet = hrnet_module.get_seg_model(config)

# modelpath = '../logs/20240619_1547-hrnet-res512/model.best.pth'
modelpath = '../logs/20240619_2345-hrnet-res1024-adam/model.best.pth'


# The checkpoint stores the weights under the 'state' key; every key must
# match the model (default strict loading).
state_dict = torch.load(modelpath)['state']
hrnet.load_state_dict(state_dict);
model = hrnet.eval().cuda()

In [None]:
# Mean IoU, per-class IoU and raw confusion counts on the validation loader.
val_metrics = get_CM_fromloader(val_loader, model, n_classes)
miou, CM_iou, CM_abs = val_metrics
for metric in (miou, CM_iou):
    print(metric.round(2))

In [None]:
# mocamba_classnames = {
#     0: 'background', 1: 'thing-Animals', 2: 'surface-Asphalt', 3: 'sign-Cat-s-Eye', 4: 'damage-Cracks', 5: 'thing-Ego',
#     6: 'surface-Hard-Sand', 7: 'sign-Markings', 8: 'thing-Obstacle', 9: 'thing-People', 10: 'damage-Pothole', 11: 'thing-Retaining-wall',
#     12: 'surface-Soft-Sand', 13: 'surface-Unpaved', 14: 'thing-Vehicles', 15: 'sign-Vertical-Signs', 16: 'surface-Wet-sand'
# }


# One "<class name>:<TAB><IoU in percent>" line per class, in the order of
# ds.mocamba_classnames.
txt = ''.join(
    f'{v}:\t{round(p*100,2)}\n'
    for v, p in zip(ds.mocamba_classnames.values(), CM_iou)
)

print(txt)

In [None]:
# Annotation list to sample from (alternatives kept for quick switching).
# fpath = 'all-trainpaths-wider.txt'
# fpath = 'all-valpaths-wider.txt'
# fpath = '../data/mocamba/ds-mocamba-v0.3.4-long1024/trainpaths.txt'
fpath = '../data/mocamba/ds-mocamba-v0.3.4-long1024/valpaths.txt'

colorizer = ut.TorchColorizer(len(ds.mocamba_classnames))


val_ds = ut.SimpleDataset(annotation_file=fpath)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)
hrnet.eval()
alpha = .2
nplot = 10

# One row per sample: image | ground truth | prediction | blend.
plot_ims = []
for inp_im, inp_label in val_loader:
    im = ut.Normalize.reverse(inp_im)
    label = colorizer(inp_label)
    pred = colorizer(get_hrnet_segmap(inp_im).cpu())
    blend = (1-alpha)*im + alpha*pred

    # Pad the last dimension of every panel on the right up to 1024 so rows
    # from images of different widths can be stacked.
    row = torch.concat(
        [
            F.pad(panel, (0, 1024-panel.shape[-1], 0, 0, 0, 0, 0, 0))
            for panel in (im, label, pred, blend)
        ],
        axis=-1,
    )
    row = row.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0)
    plot_ims.append(ut.float_to_uint8(row.numpy()))

    if len(plot_ims) >= nplot:
        break


tmp = np.concatenate(plot_ims, axis=0)
Image.fromarray(tmp)