In [None]:
!pip install -q segmentation_models_pytorch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from seaborn import color_palette
from PIL import Image
import pandas as pd
from glob import iglob
import re
from datetime import datetime
import random
import cv2
from copy import copy
from transformers import Mask2FormerForUniversalSegmentation
from importlib import reload
from glob import iglob
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from sklearn.manifold import TSNE
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

import sys
sys.path.append('../src')
import utils as ut
import dataset as ds


import segmentation_models_pytorch as smp


In [None]:
# `id2label` lookup of the 'mocamba' label set as provided by ds.IDs; its
# length is the class count passed to the mIoU evaluation later on.
label_ids = ds.IDs('mocamba')
id2label = label_ids.id2label
n_classes = len(id2label)
print(id2label)


# Resolution key of the dataset variant to use (presumably the long image side
# in pixels -- confirm). Other values previously tried: 256, 512, 1024, 2048.
long = 1536

# The training crop is 256x352 at long=512 and grows linearly with `long`;
# each side is truncated to an integer.
scale = long / 512
crop_size = tuple(int(side * scale) for side in (256, 352))
print(f'long: {long}, crop_size: {crop_size}')

# Training-time augmentation: a random crop of `crop_size`, then a horizontal
# flip applied with probability 0.5.
T_crop = T.Compose([T.RandomCrop(crop_size), T.RandomHorizontalFlip(p=0.5)])


# Annotation files of the dataset variant selected through `long`.
data_root = f'../data/tidyv01-long{long}'
train_annotation = f'{data_root}/trainpaths.txt'
val_annotation = f'{data_root}/valpaths.txt'

# Training set: shuffled batches of 8 augmented samples.
# NOTE(review): the same random crop/flip object is handed over for both the
# image and the target. Confirm that ut.SimpleDataset applies identical random
# parameters to the pair -- otherwise image and mask would be misaligned.
train_ds = ut.SimpleDataset(
    annotation_file=train_annotation,
    transform=T_crop,
    transform_target=T_crop,
)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=8)

# Validation set: no transform is passed, one sample per batch.
val_ds = ut.SimpleDataset(annotation_file=val_annotation)
val_loader = DataLoader(val_ds, batch_size=1)

In [None]:
# Take the first validation batch and keep a private copy of (at most) its
# first six samples as a fixed probe for the visual checks in this notebook.
inp_im, inp_label = next(iter(val_loader))

dummy_im = inp_im[:6].clone()
dummy_label = inp_label[:6].clone()
dummy_label.unique()
print(dummy_im.shape)

# One colour per class name plus one extra entry.
colorizer = ut.TorchColorizer(len(ds.mocamba_classnames)+1)

# De-normalised images, colour-coded ground truth and an alpha blend of both.
ims = ut.Normalize.reverse(dummy_im)
labels = colorizer(dummy_label)
alpha = .2
blend = ims*(1-alpha) + labels*alpha

# Stack image / labels / blend along the second-to-last axis, lay the batch
# out side by side along the last axis and move channels last for PIL
# (assumes batch-first, channel-second tensors -- TODO confirm).
panel = torch.cat([ims, labels, blend], dim=-2)
panel = panel.moveaxis(0, -2).flatten(-2, -1).permute(1, 2, 0)
panel = ut.float_to_uint8(panel.numpy())
Image.fromarray(panel)

# Pre-trained DeepLabV3+

In [None]:
# DeepLabV3+ on a ResNet-50 encoder initialised from ImageNet weights, with
# one output channel per mocamba class name.
n_out_channels = len(ds.mocamba_classnames)
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=n_out_channels,
)

In [None]:
# Preview of the not-yet-fine-tuned model on the probe batch.
model.eval().cuda();

# Inference only: disable autograd so no graph/activations are kept around.
with torch.no_grad():
    pred = model(dummy_im.cuda())
print(pred.shape)
# Arg-max over the class dimension -> one class id per pixel (dim 1 is kept).
pred = pred.argmax(1, keepdim=True).long()
segmentation_map = pred.cpu()

# colorizer = ut.TorchColorizer(len(ds.mocamba_classnames))
# 19-colour palette, randomly permuted (unseeded, so colours differ per run).
colorizer = ut.TorchColorizer(19)
colorizer.colors = colorizer.colors[np.random.permutation(len(colorizer.colors))]

# De-normalised images, colour-coded prediction and an alpha blend of both.
ims = ut.Normalize.reverse(dummy_im)
labels = colorizer(segmentation_map)
alpha = .2
blend = (1-alpha)*ims + alpha*labels

# Stack image / prediction / blend along the second-to-last axis, lay the
# batch out side by side along the last axis and move channels last for PIL.
tmp = torch.concat([ims, labels, blend], axis=-2)
tmp = tmp.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0)
tmp = ut.float_to_uint8(tmp.numpy())
Image.fromarray(tmp)

# Training

In [None]:
# Fresh DeepLabV3+ instance for fine-tuning (ResNet-50 encoder with ImageNet
# weights, one output channel per mocamba class name), moved to the GPU.
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=len(ds.mocamba_classnames),
).cuda()

In [None]:
# Cross-entropy term of the training objective; `criterea` calls it with the
# raw model outputs and the label map `labels[:, 0]`.
CE = nn.CrossEntropyLoss()

def compute_diceloss(preds, labels, n_classes=None):
    """Soft Dice loss, averaged over classes.

    Per class c the Dice coefficient is ``2 * sum(p_c * y_c) / sum(p_c + y_c)``,
    with the sums taken over the batch and both spatial dimensions. The loss
    is ``1 - mean_c(dice_c)``: 0 for a perfect prediction, 1 when prediction
    and target do not overlap at all.

    Args:
        preds: class probabilities of shape (B, C, H, W).
        labels: integer class ids of shape (B, 1, H, W) (int64, as required
            by ``F.one_hot``).
        n_classes: number of classes C. When omitted it is taken from
            ``preds.shape[1]``.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        AssertionError: if the one-hot target and ``preds`` differ in shape.
    """
    if n_classes is None:
        n_classes = preds.shape[1]

    # (B, H, W) ids -> (B, H, W, C) one-hot -> (B, C, H, W) to line up with preds.
    labels_onehot = F.one_hot(labels[:, 0], num_classes=n_classes).permute(0, 3, 1, 2)
    assert labels_onehot.shape == preds.shape

    # The factor 2 makes the coefficient reach 1 (loss 0) on a perfect match.
    nume = 2 * (labels_onehot * preds).sum([0, 2, 3])
    # Clamp so a class with (almost) no mass in target and prediction cannot
    # cause a division by ~0.
    denom = (labels_onehot + preds).sum([0, 2, 3]).clamp_min(1e-3)
    dice = nume / denom
    return 1 - dice.mean()
    

def criterea(preds, labels):
    """Training objective: cross-entropy plus Dice loss, equally weighted.

    Args:
        preds: raw model outputs with the class scores along dim 1.
        labels: label tensor whose index 0 along dim 1 holds the class-id map.

    Returns:
        Scalar loss tensor (cross-entropy term + Dice term).
    """
    # Cross-entropy works on the raw scores and the squeezed label map.
    ce_term = CE(preds, labels[:, 0])

    # The Dice term works on class probabilities.
    dice_term = compute_diceloss(preds.softmax(1), labels)

    return ce_term + dice_term

# Smoke test: run the objective once on the probe batch. No gradients are
# needed for this check, so autograd is switched off to save memory.
model.eval();
with torch.no_grad():
    outs = model(dummy_im.cuda())
    sanity_loss = criterea(outs, dummy_label.cuda())
sanity_loss

In [None]:
# AdamW optimiser; LR warm-up is delegated to the project's WarmupLR helper.
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
scheduler = ut.WarmupLR(opt, n_warmup_max=1000)
# The training loop calls scheduler.step() whenever `it % warmup_step == 0`.
warmup_step = 1


# Run name = minute-resolution timestamp + experiment tag; it doubles as the
# TensorBoard log directory (and checkpoint directory) below ../logs.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
time_now = f'{timestamp}-PAPER-deeplab-tidyv01+long{long}-NEW-bs8x1'
print(time_now)
logdir = os.path.join('../logs', time_now)
writer = SummaryWriter(log_dir=logdir)
writer.flush()


# Gradient scaler used around backward()/step() in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Per-step batch size; the training cell derives the gradient-accumulation
# factor from it.
bs_train = 8

# (Re)build the loaders: validation one sample at a time, training shuffled
# with `bs_train` samples per batch.
val_ds = ut.SimpleDataset(annotation_file=f'../data/tidyv01-long{long}/valpaths.txt')
val_loader = DataLoader(dataset=val_ds, batch_size=1)
train_loader = DataLoader(dataset=train_ds, batch_size=bs_train, shuffle=True)

model = model.cuda()

In [None]:
# Set to True to continue with the counters/histories already living in the
# kernel instead of resetting them.
fl_resume = False


if not fl_resume:
    # Fresh run: reset iteration counter, logging step, best score and the
    # loss histories (per logging window and overall).
    it, steps_cur, miou_best = 1, 0, 0
    losses, lossi = [], []

    # Log the starting learning rate.
    lr_cur = opt.param_groups[0]["lr"]
    print(f'lr: {lr_cur:.2e}')
    writer.add_scalars('lr', {'nn': lr_cur}, steps_cur)

    # Baseline scores before any weight update; element 0 of the helper's
    # result is what gets reported as mIoU.
    model.eval()
    miou_train = ut.get_CM_fromloader(train_loader, model, n_classes, deeplab=True)[0]
    miou_val = ut.get_CM_fromloader(val_loader, model, n_classes, deeplab=True)[0]
    print(f'\tmiou-train: {miou_train:.2f}\tmiou-val: {miou_val:.2f}')
    writer.add_scalars('miou', {'train': miou_train}, steps_cur)
    writer.add_scalars('miou', {'val': miou_val}, steps_cur)


# Gradient-accumulation factor that brings the effective batch size to 8
# (never below 1).
n_gradacc = max(8 // bs_train, 1)

# Both budgets are counted in micro-batches, so they are stretched by the
# accumulation factor: 250 / 10000 optimizer steps respectively.
print_every_n = 250 * n_gradacc
iters_total = 10000 * n_gradacc
print(f'grad-acc: {n_gradacc}')
print(f'iters_total: {iters_total}')


# Main fine-tuning loop.
# `it` counts micro-batches. An optimizer step is taken every `n_gradacc`
# micro-batches; every `print_every_n` micro-batches the model is evaluated,
# the scores are logged and checkpoints are written.
model.train()
opt.zero_grad(set_to_none=True)

# LR-decay milestones (x0.1 each). They are scaled by `n_gradacc` exactly like
# `iters_total` and `print_every_n`, so the decay stays at 50% / 90% of the
# run when gradient accumulation is enabled.
lr_decay_its = (5000 * n_gradacc, 9000 * n_gradacc)

while it <= iters_total:

    for inp_im, inp_label in train_loader:
        inp_im = inp_im.cuda()
        inp_label = inp_label.cuda()

        # Only forward pass goes into the autocast context.
        # Autocast is currently disabled (enabled=False), i.e. the forward
        # pass runs in full precision.
        with torch.cuda.amp.autocast(dtype=torch.float16, enabled=False):
            pred = model(inp_im)
            loss = criterea(pred, inp_label)

        # Divide by the accumulation factor so the accumulated gradient is the
        # mean (not the sum) over the micro-batches of one optimizer step.
        scaler.scale(loss / n_gradacc).backward()
        # Accumulate gradients
        if it % n_gradacc == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        # The unscaled loss is what gets logged.
        lossi.append(loss.item())
        it += 1

        if scheduler.fl_warmup and it % warmup_step == 0:
            scheduler.step()
            # print(f'lr: {opt.param_groups[0]["lr"]:.2e}')

        if it % print_every_n == 0:
            steps_cur += 1

            # Mean training loss over the window since the last report.
            loss_avg = np.mean(lossi)
            losses.append(loss_avg)

            model.eval()
            miou_train = ut.get_CM_fromloader(train_loader, model, n_classes, deeplab=True)[0]
            miou_val = ut.get_CM_fromloader(val_loader, model, n_classes, deeplab=True)[0]
            model.train()

            print(
                f'it: {it}\tloss: {loss_avg:.4f}'
                f'\tmiou-train: {miou_train:.3f}',
                f'\tmiou-val: {miou_val:.3f}'
            )

            writer.add_scalars('objective', {'train': loss_avg}, steps_cur)
            writer.add_scalars('miou', {'train': miou_train}, steps_cur)
            writer.add_scalars('miou', {'val': miou_val}, steps_cur)

            lossi = []

            # Always refresh the "last" checkpoint ...
            ckpt_last_path = os.path.join(logdir, 'model.last.pth')
            ut.save_ckpt(ckpt_last_path, model, opt, miou_val, it)

            # ... and keep a separate copy of the best validation score so far.
            if miou_val > miou_best:
                miou_best = miou_val
                ckpt_best_path = os.path.join(logdir, 'model.best.pth')
                ut.save_ckpt(ckpt_best_path, model, opt, miou_val, it)

        # Step-wise LR decay; checked every iteration so it does not depend on
        # the logging interval. Runs after any checkpoint of this iteration.
        if it in lr_decay_its:
            for pg in opt.param_groups:
                pg['lr'] *= .1

        if it > iters_total:
            break

In [None]:
# for pg in opt.param_groups:
#     print(pg['lr'])
#     pg['lr'] *= .1
#     print(pg['lr'])
    

In [None]:
# Visual check of the fine-tuned model on the probe batch.
model.eval()
# Inference only: no autograd. Arg-max over the class dimension gives one
# class id per pixel (dim 1 is kept so the colorizer gets the same layout as
# in the earlier preview).
with torch.no_grad():
    pred = model(dummy_im.cuda()).argmax(1, keepdim=True).long().cpu()
ims = ut.Normalize.reverse(dummy_im)
# labels = colorizer(dummy_label)
labels = colorizer(pred)
alpha = .2
blend = (1-alpha)*ims + alpha*labels
# Stack image / prediction / blend along the second-to-last axis, lay the
# batch out side by side along the last axis and move channels last for PIL.
tmp = torch.concat([ims, labels, blend], axis=-2)
tmp = tmp.moveaxis(0,-2).flatten(-2,-1).permute(1,2,0)
tmp = ut.float_to_uint8(tmp.numpy())
Image.fromarray(tmp)