In [None]:

# To get the Jupyter notebook working, run these commands in a terminal:
#   pip install jupyter
#   pip install jupyter_http_over_ws
#   jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0
# Then use the URL it prints to open the notebook in a new tab and run the cells.


!pip install matplotlib timm==0.9.5 tqdm scipy numpy tensorboardX wget scikit-image scikit-learn xformers
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git

In [None]:
# Google Colab only: mount Google Drive, then make the shared data folder the
# working directory. Skip this cell when running on a local Jupyter server.
from google.colab import drive
drive.mount('/content/drive')
import os
# NOTE(review): hard-coded Shared Drive path; adjust it if the data lives elsewhere.
os.chdir('/content/drive/Shareddrives/ARP_Hyperspectral_Algorithms/Sample_Data')

In [None]:
# Local Jupyter only: show the working directory and the sub-directories in it.
import os
print(os.getcwd())
directories = list(filter(os.path.isdir, os.listdir('.')))
print(directories)

In [1]:
#crop data

import os
import torch
import argparse
from PIL import Image
from os.path import join
from utils.utils import *
from torch.utils.data import DataLoader
from loader.dataloader import ContrastiveSegDataset, CroppedDataset
from torchvision.transforms.functional import five_crop, ten_crop
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms as T

class RandomCropComputer(Dataset):
    """Dataset that wraps ``ContrastiveSegDataset`` with a multi-crop transform.

    The crop function chosen from ``crop_type`` is handed to the wrapped
    dataset as ``extra_transform``. The output directories for the cropped
    images and labels are created on construction and exposed as
    ``img_dir`` / ``label_dir``.

    Supported ``crop_type`` values:
        * ``'five'``   - ``five_crop`` at ``crop_ratio``.
        * ``'double'`` - ``ten_crop`` at ratios 0.5 and 0.8.
        * ``'super'``  - ``ten_crop`` at ratios 0.3, 0.4, 0.5, 0.6 and 0.7.

    For ``'double'`` and ``'super'`` the ratio used in the directory name is 0.

    Raises:
        ValueError: if ``crop_type`` is not one of the values above. The check
            happens before any directory is created.
    """

    @staticmethod
    def _get_size(img, crop_ratio):
        """Return ``[height, width]`` of a crop covering ``crop_ratio`` of ``img``.

        For a 3-D ``img`` the last two axes are scaled; for a 2-D ``img`` both
        axes are scaled. Any other rank raises ``ValueError``.
        """
        if len(img.shape) == 3:
            return [int(img.shape[1] * crop_ratio), int(img.shape[2] * crop_ratio)]
        elif len(img.shape) == 2:
            return [int(img.shape[0] * crop_ratio), int(img.shape[1] * crop_ratio)]
        else:
            raise ValueError("Bad image shape {}".format(img.shape))

    def __init__(self, args, dataset_name, img_set, crop_type, crop_ratio):
        self.pytorch_data_dir = args.data_dir
        # keep the ratio the caller asked for, even if the directory name uses 0
        self.crop_ratio = crop_ratio

        if crop_type == 'five':
            crop_func = lambda x: five_crop(x, self._get_size(x, crop_ratio))
        elif crop_type == 'double':
            crop_ratio = 0
            crop_func = lambda x: (ten_crop(x, self._get_size(x, 0.5))
                                   + ten_crop(x, self._get_size(x, 0.8)))
        elif crop_type == 'super':
            crop_ratio = 0
            crop_func = lambda x: (ten_crop(x, self._get_size(x, 0.3))
                                   + ten_crop(x, self._get_size(x, 0.4))
                                   + ten_crop(x, self._get_size(x, 0.5))
                                   + ten_crop(x, self._get_size(x, 0.6))
                                   + ten_crop(x, self._get_size(x, 0.7)))
        else:
            # fail fast: otherwise crop_func stays unbound and the error only
            # shows up after the output directories have been created
            raise ValueError(
                "Unknown crop_type {!r}; expected 'five', 'double' or 'super'".format(crop_type))

        # both COCO variants live under the shared 'cocostuff' folder
        if args.dataset in ('coco171', 'coco81'):
            root_name, prefix = 'cocostuff', args.dataset
        else:
            root_name, prefix = dataset_name, dataset_name
        self.save_dir = join(
            args.data_dir, root_name, "cropped", "{}_{}_crop_{}".format(prefix, crop_type, crop_ratio))
        self.args = args

        self.img_dir = join(self.save_dir, "img", img_set)
        self.label_dir = join(self.save_dir, "label", img_set)
        os.makedirs(self.img_dir, exist_ok=True)
        os.makedirs(self.label_dir, exist_ok=True)

        # underlying dataset; the crop function is applied as its extra transform
        self.dataset = ContrastiveSegDataset(
            pytorch_data_dir=args.data_dir,
            dataset_name=args.dataset,
            crop_type=crop_type,
            image_set=img_set,
            transform=T.ToTensor(),
            target_transform=ToTargetTensor(),
            extra_transform=crop_func
        )

    def __getitem__(self, item):
        """Return sample ``item`` of the wrapped dataset."""
        return self.dataset[item]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)


def my_app():
    """Write the cropped training set to disk.

    Builds a ``RandomCropComputer`` for the ``train`` split, iterates it one
    sample at a time and saves every (image, label) crop pair as
    ``<counter>.jpg`` in ``dataset.img_dir`` and ``<counter>.png`` in
    ``dataset.label_dir``, where ``counter`` runs over all crops written.
    Arguments come from argparse defaults (``parse_args(args=[])``), so edit
    the defaults below to change the configuration.
    """

    # NOTE: when using a custom dataset, the ContrastiveSegDataset class in
    # dataloader.py currently needs some values hard-coded.

    # fetch args
    parser = argparse.ArgumentParser()

    # fixed parameter
    # os.cpu_count() may return None when the CPU count cannot be determined
    parser.add_argument('--num_workers', default=(os.cpu_count() or 1) // 8, type=int)

    # dataset and baseline
    parser.add_argument('--data_dir', default='../', type=str)
    parser.add_argument('--dataset', default='freiburg', type=str)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--distributed', default='false', type=str2bool)
    parser.add_argument('--crop_type', default='five', type=str)
    parser.add_argument('--crop_ratio', default=0.5, type=float)

    args = parser.parse_args(args=[])

    # setting gpu id of this process
    torch.cuda.set_device(args.gpu)

    counter = 0
    dataset = RandomCropComputer(args, args.dataset, "train", args.crop_type, args.crop_ratio)
    # batch size 1 with an identity collate: each batch is a one-element list of samples
    loader = DataLoader(dataset, 1, shuffle=False, num_workers=args.num_workers,
                        collate_fn=lambda samples: samples)
    for batch in tqdm(loader):
        imgs = batch[0]['img']
        labels = batch[0]['label']
        for img, label in zip(imgs, labels):
            # scale to 0..255 with rounding, move channels last, convert to uint8
            # (assumes img is a channel-first float tensor in [0, 1] - TODO confirm)
            img_arr = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
            # labels are shifted by +1 before being stored as uint8
            label_arr = (label + 1).unsqueeze(0).permute(1, 2, 0).to('cpu', torch.uint8).numpy().squeeze(-1)
            Image.fromarray(img_arr).save(join(dataset.img_dir, "{}.jpg".format(counter)), 'JPEG')
            Image.fromarray(label_arr).save(join(dataset.label_dir, "{}.png".format(counter)), 'PNG')
            counter += 1

# Entry point: run the cropping job when this code executes as the main module.
if __name__ == "__main__":
    my_app()

100%|██████████| 366/366 [00:05<00:00, 63.53it/s]


In [2]:
# train mediator python file 

import argparse

from tqdm import tqdm
from utils.utils import *
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from modules.segment_module import compute_modularity_based_codebook
from loader.dataloader import dataloader
from loader.netloader import network_loader, cluster_mlp_loader
from torch.cuda.amp import autocast, GradScaler
# from loader.netloader import network_loader, cluster_mlp_loader

# Turn on the cuDNN benchmark flag for this process.
cudnn.benchmark = True
# Module-level AMP gradient scaler; train() below uses it for backward/step/update.
scaler = GradScaler()

def ddp_setup(args, rank, world_size):
    """Join the NCCL process group used for distributed training.

    The rendezvous address is fixed to ``localhost`` and the port is taken
    from ``args.port``; both are exported as environment variables before the
    process group is initialised for ``rank`` out of ``world_size`` processes.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': args.port})

    # initialize
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def ddp_clean():
    """Destroy the current ``torch.distributed`` process group."""
    dist.destroy_process_group()

@Wrapper.EpochPrint
def train(args, net, cluster, train_loader, optimizer):
    """Run one pass over ``train_loader`` optimising the modularity loss.

    For every batch the backbone output (with index 0 of dim 1 removed) and
    ``cluster.codebook`` are fed to ``compute_modularity_based_codebook``; the
    resulting loss is back-propagated through the module-level AMP ``scaler``
    and ``optimizer`` is stepped.

    The call site passes ``epoch`` and ``rank`` ahead of these arguments; its
    comments mark them as being for the decorator.
    """
    progress = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
    for _, sample in progress:
        images = sample["img"].cuda()

        with autocast():
            # backbone tokens without the entry at index 0 of dim 1
            tokens = net(images)[:, 1:, :]

            # modularity objective computed against the codebook
            modularity_loss = compute_modularity_based_codebook(cluster.codebook, tokens, grid=args.grid)

        # scaled backward pass followed by the optimizer step
        optimizer.zero_grad()
        scaler.scale(modularity_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # progress-bar label
        progress.set_description('[Train]', refresh=True)

        # keep the GPU processes in lock-step
        if args.distributed:
            dist.barrier()

def main(rank, args, ngpus_per_node):
    """Train the modularity-based codebook and save it, unless one already exists.

    Args:
        rank: GPU index of this process; also the rank used for DDP setup.
        args: parsed argparse namespace with the run configuration.
        ngpus_per_node: number of processes/GPUs; also scales the learning rate.

    ``pickle_path_and_exist(args)`` supplies the target path and whether it is
    already present. If it is, training is skipped. Otherwise the codebook is
    trained for ``args.epoch`` epochs and rank 0 saves it after each epoch.
    """
    # join the process group when running distributed
    if args.distributed:
        ddp_setup(args, rank, ngpus_per_node)

    # bind this process to its GPU and show the configuration
    torch.cuda.set_device(rank)
    print_argparse(args, rank)

    # data
    train_loader, _, sampler = dataloader(args)

    # networks
    net = network_loader(args, rank)
    cluster = cluster_mlp_loader(args, rank)

    # unwrap the distributed wrappers so attributes are reachable directly
    if args.distributed:
        net = net.module
        cluster = cluster.module

    # only the cluster parameters are optimised; lr decays by 0.2 every epoch
    optimizer = torch.optim.Adam(cluster.parameters(), lr=1e-3 * ngpus_per_node)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2)

    path, is_exist = pickle_path_and_exist(args)

    if is_exist:
        rprint("Already Exists!!", rank)
    else:
        rprint("No File Exists!!", rank)
        for epoch in range(args.epoch):

            # reshuffle the distributed sampler for this epoch
            if args.distributed:
                sampler.set_epoch(epoch)

            # epoch and rank come first; the call site marks them as decorator arguments
            train(epoch, rank, args, net, cluster, train_loader, optimizer)

            scheduler.step()

            # rank 0 writes the current codebook after every epoch
            if rank == 0:
                np.save(path, cluster.codebook.detach().cpu().numpy())

            # keep the GPU processes in lock-step
            if args.distributed:
                dist.barrier()

    # leave the process group
    if args.distributed:
        ddp_clean()


# Entry point of the "train mediator" step: build the configuration from the
# argparse defaults (parse_args(args=[])), derive the backbone-dependent
# settings from the checkpoint name, then launch main() on one or several GPUs.
if __name__ == "__main__":

    # NOTE: when using a custom dataset, the dataloader code in dataloader.py
    # currently needs some values hard-coded.

    # fetch args
    parser = argparse.ArgumentParser()

    # fixed parameter
    parser.add_argument('--epoch', default=1, type=int)
    # parser.add_argument('--distributed', default=True, type=str2bool)
    parser.add_argument('--distributed', default=False, type=str2bool)

    parser.add_argument('--load_segment', default=False, type=str2bool)
    parser.add_argument('--load_cluster', default=False, type=str2bool)
    parser.add_argument('--train_resolution', default=320, type=int)
    parser.add_argument('--test_resolution', default=320, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    # os.cpu_count() may return None when the CPU count cannot be determined
    parser.add_argument('--num_workers', default=(os.cpu_count() or 1) // 8, type=int)

    # dataset and baseline
    parser.add_argument('--data_dir', default='../', type=str)
    # parser.add_argument('--dataset', default='cocostuff27', type=str)
    parser.add_argument('--dataset', default='freiburg', type=str)

    parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str)

    # DDP
    # parser.add_argument('--gpu', default='0,1,2,3', type=str)
    parser.add_argument('--gpu', default='0', type=str)

    parser.add_argument('--port', default='12355', type=str)

    # parameter
    parser.add_argument('--grid', default='yes', type=str2bool)
    parser.add_argument('--num_codebook', default=2048, type=int)

    # model parameter
    parser.add_argument('--reduced_dim', default=90, type=int)
    parser.add_argument('--projection_dim', default=2048, type=int)

    args = parser.parse_args(args=[])

    # settings derived from the checkpoint file name
    if 'dinov2' in args.ckpt:
        args.train_resolution=322
        args.test_resolution=322
    if 'small' in args.ckpt:
        args.dim=384
    elif 'base' in args.ckpt:
        args.dim=768
    # NOTE(review): args.dim stays unset when the name contains neither
    # 'small' nor 'base'.

    # the number of gpus for multi-process
    gpu_list = list(map(int, args.gpu.split(',')))
    ngpus_per_node = len(gpu_list)

    if args.distributed:
        # cuda visible devices
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        # multiprocess spawn
        mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True)
    else:
        # first gpu index is activated once there are several gpu in args.gpu
        main(rank=gpu_list[0], args=args, ngpus_per_node=1)

------------------Configurations------------------
epoch: 1
distributed: False
load_segment: False
load_cluster: False
train_resolution: 320
test_resolution: 320
batch_size: 16
num_workers: 2
data_dir: ../
dataset: freiburg
ckpt: checkpoint/dino_vit_base_8.pth
gpu: 0
port: 12355
grid: True
num_codebook: 2048
reduced_dim: 90
projection_dim: 2048
dim: 768
-------------------------------------------------
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Already Exists!!


In [5]:
# train front door tr python file
import argparse

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)


from tqdm import tqdm
from utils.utils import *
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from modules.segment_module import stochastic_sampling, ema_init, ema_update
from loader.dataloader import dataloader
from torch.cuda.amp import autocast, GradScaler
from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader
from tensorboardX import SummaryWriter

# Turn on the cuDNN benchmark flag for this process.
cudnn.benchmark = True
# Module-level AMP gradient scaler; train() below uses it for backward/step/update.
scaler = GradScaler()

# tensorboard
# `counter` is the step index passed to writer.add_scalar in train() and is
# incremented there; `counter_test` is declared global in test() but is not
# modified by the code in this cell.
counter = 0
counter_test = 0

def ddp_setup(args, rank, world_size):
    """Join the NCCL process group used for distributed training.

    Exports ``MASTER_ADDR`` (always ``localhost``) and ``MASTER_PORT``
    (``args.port``) and then initialises the process group for ``rank`` out
    of ``world_size`` processes.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': args.port})

    # initialize
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def ddp_clean():
    """Destroy the current ``torch.distributed`` process group."""
    dist.destroy_process_group()


@Wrapper.EpochPrint
def train(args, net, segment, cluster, train_loader, optimizer_segment, writer, rank):
    """Run one training pass over ``train_loader`` for the segment head.

    Each step combines the contrastive loss returned by
    ``cluster.contrastive_ema_with_codebook_bank`` with a cross-entropy
    linear-probe loss, clips the gradient norm of ``segment``, steps
    ``optimizer_segment`` through the module-level AMP ``scaler``, then
    updates the EMA heads and the cluster bank. Running loss and accuracy are
    shown on the progress bar. Scalars are written to ``writer`` only for
    rank 0 of a distributed run.

    The call site passes ``epoch`` and ``rank`` ahead of these arguments; its
    comments mark them as being for the decorator.
    """
    global counter
    segment.train()

    # gradient-norm ceiling: 1 for cityscapes, 2 for every other dataset
    max_grad_norm = 1 if args.dataset == 'cityscapes' else 2

    total_acc = 0
    total_loss = 0
    total_loss_front = 0
    total_loss_linear = 0

    prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
    for idx, batch in prog_bar:

        with autocast():

            # inputs
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            # backbone tokens without the entry at index 0 of dim 1
            feat = net(img)[:, 1:, :]

            # teacher (EMA) branch
            seg_feat_ema = segment.head_ema(feat, drop=segment.dropout)
            proj_feat_ema = segment.projection_head_ema(seg_feat_ema)

            # student branch
            seg_feat = segment.head(feat, drop=segment.dropout)
            proj_feat = segment.projection_head(seg_feat)

            # optional sampling; the same `order` is reused for all three tensors
            if args.grid:
                feat, order = stochastic_sampling(feat)
                proj_feat, _ = stochastic_sampling(proj_feat, order=order)
                proj_feat_ema, _ = stochastic_sampling(proj_feat_ema, order=order)

            # bank compute and contrastive loss
            cluster.bank_compute()
            loss_front = cluster.contrastive_ema_with_codebook_bank(feat, proj_feat, proj_feat_ema)

            # linear probe loss on the teacher features, restricted to valid labels
            linear_logits = segment.linear(seg_feat_ema)
            linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
            flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes)
            flat_label = label.reshape(-1)
            flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)
            loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask])

            # total loss
            loss = loss_front + loss_linear

        # backward, unscale so clipping sees true gradient magnitudes, then step
        optimizer_segment.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer_segment)
        torch.nn.utils.clip_grad_norm_(segment.parameters(), max_grad_norm)
        scaler.step(optimizer_segment)
        scaler.update()

        # ema update
        ema_update(segment.head, segment.head_ema)
        ema_update(segment.projection_head, segment.projection_head_ema)

        # bank update
        cluster.bank_update(feat, proj_feat_ema)

        # linear probe accuracy over the valid pixels
        flat_pred_label = linear_logits.argmax(dim=1).view(-1)
        valid_pred = flat_pred_label[flat_label_mask]
        valid_label = flat_label[flat_label_mask]
        acc = (valid_pred == valid_label).sum() / valid_label.numel()
        total_acc += acc.item()

        # running sums
        total_loss += loss.item()
        total_loss_front += loss_front.item()
        total_loss_linear += loss_linear.item()

        # real-time print of the running means
        seen = idx + 1
        desc = f'[Train] Loss: {total_loss / seen:.2f}={total_loss_front / seen:.2f}+{total_loss_linear / seen:.2f}'
        desc += f' ACC: {100. * total_acc / seen:.1f}%'
        prog_bar.set_description(desc, refresh=True)

        # tensorboard
        if (args.distributed == True) and (rank == 0):
            writer.add_scalar('Train/Contrastive', loss_front, counter)
            writer.add_scalar('Train/Linear', loss_linear, counter)
            writer.add_scalar('Train/Acc', total_acc / seen, counter)
            counter += 1

        # keep the GPU processes in lock-step
        if args.distributed:
            dist.barrier()


@Wrapper.TestPrint
def test(args, net, segment, nice, test_loader):
    """Evaluate linear-probe pixel accuracy of the EMA head on ``test_loader``.

    The running mean accuracy is shown on the progress bar; nothing is
    returned. ``nice`` is only reset at the end, and ``counter_test`` is
    declared global but not modified here.

    The call site passes ``epoch`` and ``rank`` ahead of these arguments; its
    comments mark them as being for the decorator.
    """
    global counter_test
    segment.eval()

    total_acc = 0
    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
    for idx, batch in prog_bar:
        # image and label and self supervised feature
        img = batch["img"].cuda()
        label = batch["label"].cuda()

        # intermediate feature
        with autocast():
            # backbone tokens without the entry at index 0 of dim 1
            feat = net(img)[:, 1:, :]
            seg_feat_ema = segment.head_ema(feat)

            # linear probe logits, resized to the label resolution
            linear_logits = segment.linear(seg_feat_ema)
            linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
            flat_label = label.view(-1)
            # only labels in [0, n_classes) count towards the accuracy
            flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)

        # linear probe acc check
        pred_label = linear_logits.argmax(dim=1)
        flat_pred_label = pred_label.view(-1)
        acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[
            flat_label_mask].numel()
        total_acc += acc.item()

        # real-time print
        desc = f'[TEST] Acc (Linear): {100. * total_acc / (idx + 1):.1f}%'
        prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()

    # Interrupt for sync GPU Process
    if args.distributed: dist.barrier()


def main(rank, args, ngpus_per_node):
    """Train the segment head against a previously saved modularity codebook.

    Args:
        rank: GPU index of this process; also the rank used for DDP setup.
        args: parsed argparse namespace with the run configuration.
        ngpus_per_node: number of processes/GPUs; also scales the learning rate.

    The codebook at the path returned by ``pickle_path_and_exist(args)`` must
    already exist (it is produced by the mediator training step); otherwise a
    message is printed and the function returns without training. After each
    epoch's train and test pass, rank 0 saves ``segment.state_dict()`` to
    ``CAUSE/<dataset>/<baseline>/<num_codebook>/segment_tr.pth``.
    """

    # setup ddp process
    if args.distributed: ddp_setup(args, rank, ngpus_per_node)

    # setting gpu id of this process
    torch.cuda.set_device(rank)

    # print argparse
    print_argparse(args, rank)

    # dataset loader
    train_loader, test_loader, sampler = dataloader(args)

    # network loader
    net = network_loader(args, rank)
    segment = segment_tr_loader(args, rank)
    cluster = cluster_tr_loader(args, rank)

    # distributed parsing
    if args.distributed: net = net.module; segment = segment.module; cluster = cluster.module

    # Bank and EMA
    cluster.bank_init()
    ema_init(segment.head, segment.head_ema)
    ema_init(segment.projection_head, segment.projection_head_ema)

    ###################################################################################
    # First, run train_mediator.py
    path, is_exist = pickle_path_and_exist(args)

    # early save for time
    if is_exist:
        # load the frozen codebook into the cluster module and both segment heads
        codebook = np.load(path)
        cluster.codebook.data = torch.from_numpy(codebook).cuda()
        cluster.codebook.requires_grad = False
        segment.head.codebook = torch.from_numpy(codebook).cuda()
        segment.head_ema.codebook = torch.from_numpy(codebook).cuda()

        # print successful loading modularity
        rprint(f'Modularity {path} loaded', rank)

        # Interrupt for sync GPU Process
        if args.distributed: dist.barrier()

    else:
        rprint('Train Modularity-based Codebook First', rank)
        return
    ###################################################################################

    # optimizer (no weight decay for cityscapes)
    if args.dataset=='cityscapes':
        optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node)
    else:
        optimizer_segment = torch.optim.Adam(segment.parameters(), lr=1e-3 * ngpus_per_node, weight_decay=1e-4)

    # tensorboard: only rank 0 of a distributed run gets a writer
    writer = None
    if (args.distributed == True) and (rank == 0):
        from datetime import datetime
        # __file__ is not defined when this code runs inside a notebook kernel,
        # so fall back to a fixed tag instead of raising NameError
        script_name = os.path.basename(globals().get('__file__', 'notebook'))
        # NOTE(review): the [2:] slice drops the leading space AND the first
        # digit of the month - confirm this timestamp format is intended.
        log_dir = os.path.join('logs',
                               datetime.today().strftime(" %m:%d_%H:%M")[2:],
                               args.dataset,
                               "_".join(
            [args.ckpt.split('/')[-1].split('.')[0],
             str(args.num_codebook),
             script_name]))
        check_dir(log_dir)
        writer = SummaryWriter(log_dir=log_dir)

    # evaluation
    nice = NiceTool(args.n_classes)


    # train
    for epoch in range(args.epoch):

        # for shuffle
        if args.distributed: sampler.set_epoch(epoch)


        # train
        train(
            epoch,  # for decorator
            rank,  # for decorator
            args,
            net,
            segment,
            cluster,
            train_loader,
            optimizer_segment,
            writer, rank)


        test(
            epoch, # for decorator
            rank, # for decorator
            args,
            net,
            segment,
            nice,
            test_loader)

        if (rank == 0):
            x = segment.state_dict()
            baseline = args.ckpt.split('/')[-1].split('.')[0]

            # filepath hierarchy
            check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}')

            # save path (overwritten every epoch)
            y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_tr.pth'
            torch.save(x, y)
            print(f'-----------------TEST Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------')

        # Interrupt for sync GPU Process
        if args.distributed: dist.barrier()

    # flush and release the tensorboard writer, if one was created
    if writer is not None:
        writer.close()

    # Closing DDP
    if args.distributed: dist.barrier(); dist.destroy_process_group()


# Entry point of the "train front door" step: build the configuration from the
# argparse defaults (parse_args(args=[])), derive the backbone-dependent
# settings from the checkpoint name, then launch main() on one or several GPUs.
if __name__ == "__main__":

    # fetch args
    parser = argparse.ArgumentParser()
    # model parameter
    # stored on the namespace as NAME_TAG (see the printed configuration)
    parser.add_argument('--NAME-TAG', default='CAUSE-TR', type=str)
    parser.add_argument('--data_dir', default='../', type=str)
    parser.add_argument('--dataset', default='freiburg', type=str)
    parser.add_argument('--ckpt', default='checkpoint/dino_vit_base_8.pth', type=str)
    parser.add_argument('--epoch', default=2, type=int)
    # parser.add_argument('--distributed', default=True, type=str2bool)
    parser.add_argument('--distributed', default=False, type=str2bool)
    parser.add_argument('--load_segment', default=False, type=str2bool)
    parser.add_argument('--load_cluster', default=False, type=str2bool)
    parser.add_argument('--train_resolution', default=320, type=int)
    parser.add_argument('--test_resolution', default=320, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    # parser.add_argument('--num_workers', default=int(os.cpu_count() / 8), type=int)
    parser.add_argument('--num_workers', default=1, type=int)

    # DDP
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--port', default='12355', type=str)
    
    # codebook parameter
    parser.add_argument('--grid', default='yes', type=str2bool)
    parser.add_argument('--num_codebook', default=2048, type=int)

    # model parameter
    parser.add_argument('--reduced_dim', default=90, type=int)
    parser.add_argument('--projection_dim', default=2048, type=int)

    args = parser.parse_args(args=[])

    # settings derived from the checkpoint file name
    if 'dinov2' in args.ckpt:
        args.train_resolution=322
        args.test_resolution=322
    if 'small' in args.ckpt:
        args.dim=384
    elif 'base' in args.ckpt:
        args.dim=768
    # num_queries = train_resolution^2 // k^2, where k is the integer after the
    # last '_' in the checkpoint name (8 for the default; presumably the ViT
    # patch size - TODO confirm). The defaults give 320^2 // 8^2 = 1600.
    args.num_queries=args.train_resolution**2 // int(args.ckpt.split('_')[-1].split('.')[0])**2

    # the number of gpus for multi-process
    gpu_list = list(map(int, args.gpu.split(',')))
    ngpus_per_node = len(gpu_list)

    if args.distributed:
        # cuda visible devices
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        # multiprocess spawn
        mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True)
    else:
        # first gpu index is activated once there are several gpu in args.gpu
        main(rank=gpu_list[0], args=args, ngpus_per_node=1)


------------------Configurations------------------
NAME_TAG: CAUSE-TR
data_dir: ../
dataset: freiburg
ckpt: checkpoint/dino_vit_base_8.pth
epoch: 2
distributed: False
load_segment: False
load_cluster: False
train_resolution: 320
test_resolution: 320
batch_size: 16
num_workers: 1
gpu: 0
port: 12355
grid: True
num_codebook: 2048
reduced_dim: 90
projection_dim: 2048
dim: 768
num_queries: 1600
-------------------------------------------------
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Modularity CAUSE/freiburg/modularity/dino_vit_base_8/2048/modular.npy loaded
-------------TRAIN EPOCH: 1-------------


[Train] Loss: 7.22=6.56+0.66 ACC: 77.5%: 100%|██████████| 23/23 [00:10<00:00,  2.27it/s]

-------------TEST EPOCH: 1-------------



[TEST] Acc (Linear): 100.0%: 100%|██████████| 23/23 [00:05<00:00,  3.90it/s]

-----------------TEST Epoch 0: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/segment_tr.pth-----------------
-------------TRAIN EPOCH: 2-------------



[Train] Loss: 6.52=6.50+0.02 ACC: 100.0%: 100%|██████████| 23/23 [00:09<00:00,  2.33it/s]

-------------TEST EPOCH: 2-------------



[TEST] Acc (Linear): 100.0%: 100%|██████████| 23/23 [00:05<00:00,  3.87it/s]

-----------------TEST Epoch 1: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/segment_tr.pth-----------------





In [7]:
# fine tuning tr python file

import argparse

import torch.nn.init
from tqdm import tqdm
from utils.utils import *
from modules.segment_module import transform, untransform, compute_modularity_based_codebook
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from loader.dataloader import dataloader
from torch.cuda.amp import autocast, GradScaler
from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader

# Turn on the cuDNN benchmark flag for this process.
cudnn.benchmark = True
# Module-level AMP gradient scaler; train() below uses it for backward/step/update.
scaler = GradScaler()

# Colour map built by create_pascal_label_colormap(); it is not referenced by
# the functions visible in this part of the cell.
cmap = create_pascal_label_colormap()

def ddp_setup(args, rank, world_size):
    """Join the NCCL process group used for distributed training.

    Exports ``MASTER_ADDR`` (always ``localhost``) and ``MASTER_PORT``
    (``args.port``) and then initialises the process group for ``rank`` out
    of ``world_size`` processes.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': args.port})

    # initialize
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def ddp_clean():
    """Destroy the current ``torch.distributed`` process group."""
    dist.destroy_process_group()


@Wrapper.EpochPrint
def train(args, net, segment, cluster, train_loader, optimizer_segment, optimizer_cluster):
    """Run one fine-tuning pass over ``train_loader``.

    Each step sums a cross-entropy linear-probe loss and the modularity loss
    computed from ``cluster.cluster_probe`` and the EMA-head features, then
    steps both optimizers through the module-level AMP ``scaler``. Running
    loss and accuracy are shown on the progress bar; nothing is returned.

    The call site is outside the visible source; presumably ``epoch`` and
    ``rank`` are passed ahead of these arguments for the decorator - TODO confirm.
    """
    # NOTE(review): `counter` is declared global but never used in this function.
    global counter
    segment.train()

    total_acc = 0
    total_loss = 0
    total_loss_linear = 0
    total_loss_mod = 0

    prog_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
    for idx, batch in prog_bar:

        # forward pass under autocast
        with autocast():

            # image and label and self supervised feature
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            # intermediate features (entry at index 0 of dim 1 is dropped)
            feat = net(img)[:, 1:, :]
            seg_feat_ema = segment.head_ema(feat, segment.dropout)

            # computing modularity based codebook
            loss_mod = compute_modularity_based_codebook(cluster.cluster_probe, seg_feat_ema, grid=args.grid)

            # linear probe loss, restricted to labels in [0, n_classes)
            linear_logits = segment.linear(seg_feat_ema)
            linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
            flat_linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, args.n_classes)
            flat_label = label.reshape(-1)
            flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)
            loss_linear = F.cross_entropy(flat_linear_logits[flat_label_mask], flat_label[flat_label_mask])

            # loss
            loss = loss_linear + loss_mod

        # optimizer
        optimizer_segment.zero_grad()
        optimizer_cluster.zero_grad()
        scaler.scale(loss).backward()
        # NOTE(review): gradients are unscaled and clipped only for cityscapes
        # and cocostuff27; any other dataset steps without clipping.
        if args.dataset=='cityscapes':
            scaler.unscale_(optimizer_segment)
            torch.nn.utils.clip_grad_norm_(segment.parameters(), 1)
        elif args.dataset=='cocostuff27':
            scaler.unscale_(optimizer_segment)
            torch.nn.utils.clip_grad_norm_(segment.parameters(), 2)
        scaler.step(optimizer_segment)
        scaler.step(optimizer_cluster)
        scaler.update()

        # linear probe acc check
        pred_label = linear_logits.argmax(dim=1)
        flat_pred_label = pred_label.reshape(-1)
        acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[
            flat_label_mask].numel()
        total_acc += acc.item()

        # loss check
        total_loss += loss.item()
        total_loss_linear += loss_linear.item()
        total_loss_mod += loss_mod.item()

        # real-time print
        # NOTE(review): there is no separator between the linear and modularity
        # terms; presumably the modularity loss is negative so its sign acts as
        # the separator - confirm.
        desc = f'[Train] Loss: {total_loss / (idx + 1):.2f}={total_loss_linear / (idx + 1):.2f}{total_loss_mod / (idx + 1):.2f}'
        desc += f' ACC: {100. * total_acc / (idx + 1):.1f}%'
        prog_bar.set_description(desc, refresh=True)

        # Interrupt for sync GPU Process
        if args.distributed: dist.barrier()


@Wrapper.TestPrint
def test(args, net, segment, cluster, nice, test_loader):
    """Evaluate the linear probe and the cluster predictions on ``test_loader``.

    For every batch the linear-probe pixel accuracy is accumulated and the
    cluster predictions are passed to ``nice.eval``; both summaries are shown
    on the progress bar. ``nice`` is reset at the end and nothing is returned.

    The call site is outside the visible source; presumably ``epoch`` and
    ``rank`` are passed ahead of these arguments for the decorator - TODO confirm.
    """
    # NOTE(review): `counter_test` is declared global but never used in this function.
    global counter_test
    segment.eval()

    total_acc = 0
    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
    for idx, batch in prog_bar:
        # image and label and self supervised feature
        img = batch["img"].cuda()
        label = batch["label"].cuda()

        # intermediate feature
        with autocast():

            # backbone tokens without the entry at index 0 of dim 1
            feat = net(img)[:, 1:, :]
            seg_feat_ema = segment.head_ema(feat)

            # linear probe logits, resized to the label resolution
            linear_logits = segment.linear(seg_feat_ema)
            linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
            flat_label = label.reshape(-1)
            # only labels in [0, n_classes) count towards the accuracy
            flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)

            # features resized to the label resolution via transform()
            interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False)

            # cluster predictions from the untransformed, resized features
            cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True)

        # linear probe acc check
        pred_label = linear_logits.argmax(dim=1)
        flat_pred_label = pred_label.reshape(-1)
        acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[
            flat_label_mask].numel()
        total_acc += acc.item()

        # nice evaluation
        _, desc_nice = nice.eval(cluster_preds, label)

        # real-time print
        desc = f'[TEST] Acc (Linear): {100. * total_acc / (idx + 1):.1f}% | {desc_nice}'
        prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()

    # Interrupt for sync GPU Process
    if args.distributed: dist.barrier()





def main(rank, args, ngpus_per_node):
    """Per-process entry point: train the segment and cluster modules.

    Args:
        rank: CUDA device index used by this process; rank 0 also writes
            the checkpoints.
        args: parsed run configuration (see the argparse setup in this cell).
        ngpus_per_node: number of processes/GPUs; the base learning rate of
            1e-3 is multiplied by it.

    Returns:
        None. Returns early, after a log message, when
        ``pickle_path_and_exist`` reports that the codebook file is missing.
    """
    # the process group has to exist before anything else when running DDP
    if args.distributed:
        ddp_setup(args, rank, ngpus_per_node)

    # bind this process to its GPU and show the configuration
    torch.cuda.set_device(rank)
    print_argparse(args, rank)

    # data
    train_loader, test_loader, sampler = dataloader(args)

    # models
    net = network_loader(args, rank)
    segment = segment_tr_loader(args, rank)
    cluster = cluster_tr_loader(args, rank)

    # under DDP the loaders return wrappers; work on the wrapped modules
    if args.distributed:
        net, segment, cluster = net.module, segment.module, cluster.module

    # optimizers: weight decay is applied to the segment module only, and
    # only for datasets other than cityscapes
    scaled_lr = 1e-3 * ngpus_per_node
    segment_extra = {} if args.dataset == 'cityscapes' else {'weight_decay': 1e-4}
    optimizer_segment = torch.optim.Adam(segment.parameters(), lr=scaled_lr, **segment_extra)
    optimizer_cluster = torch.optim.Adam(cluster.parameters(), lr=scaled_lr)

    # schedulers: learning rate is halved every second scheduler step
    scheduler_segment = torch.optim.lr_scheduler.StepLR(optimizer_segment, step_size=2, gamma=0.5)
    scheduler_cluster = torch.optim.lr_scheduler.StepLR(optimizer_cluster, step_size=2, gamma=0.5)

    # evaluation
    nice = NiceTool(args.n_classes)

    # The codebook comes from an earlier stage (train_mediator.py); without
    # it there is nothing to train against.
    path, is_exist = pickle_path_and_exist(args)
    if not is_exist:
        rprint('Train Modularity-based Codebook First', rank)
        return

    # share one frozen codebook tensor between the cluster module and both heads
    cb = torch.from_numpy(np.load(path)).cuda()
    cluster.codebook.data = cb
    cluster.codebook.requires_grad = False
    for head in (segment.head, segment.head_ema):
        head.codebook = cb
    rprint(f'Modularity {path} loaded', rank)

    # Interrupt for sync GPU Process
    if args.distributed:
        dist.barrier()

    for epoch in range(args.epoch):

        # reshuffle the distributed sampler for this epoch
        if args.distributed:
            sampler.set_epoch(epoch)

        # the leading epoch and rank arguments are for the decorators
        train(epoch, rank, args, net, segment, cluster, train_loader,
              optimizer_segment, optimizer_cluster)
        test(epoch, rank, args, net, segment, cluster, nice, test_loader)

        scheduler_segment.step()
        scheduler_cluster.step()

        # only rank 0 writes checkpoints; they are overwritten every epoch
        if rank == 0:
            baseline = args.ckpt.split('/')[-1].split('.')[0]

            # filepath hierarchy
            check_dir(f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}')

            segment_path = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/segment_tr.pth'
            torch.save(segment.state_dict(), segment_path)

            y = f'CAUSE/{args.dataset}/{baseline}/{args.num_codebook}/cluster_tr.pth'
            torch.save(cluster.state_dict(), y)
            print(f'-----------------TEST Epoch {epoch}: SAVING CHECKPOINT IN {y}-----------------')

        # Interrupt for sync GPU Process
        if args.distributed:
            dist.barrier()

    # Closing DDP
    if args.distributed:
        dist.barrier()
        dist.destroy_process_group()

if __name__ == "__main__":

    # (flag, default, type) triples, registered in exactly this order
    cli_options = [
        # model parameter
        ('--NAME-TAG', 'CAUSE-TR', str),
        ('--data_dir', '../', str),
        ('--dataset', 'freiburg', str),
        ('--ckpt', 'checkpoint/dino_vit_base_8.pth', str),
        ('--epoch', 5, int),
        ('--distributed', False, str2bool),
        ('--load_segment', True, str2bool),
        ('--load_cluster', False, str2bool),
        ('--train_resolution', 320, int),
        ('--test_resolution', 320, int),
        ('--batch_size', 16, int),
        ('--num_workers', 3, int),
        # DDP
        ('--gpu', '0', str),
        ('--port', '12355', str),
        # codebook parameter
        ('--grid', 'yes', str2bool),
        ('--num_codebook', 2048, int),
        # model parameter
        ('--reduced_dim', 90, int),
        ('--projection_dim', 2048, int),
    ]

    parser = argparse.ArgumentParser()
    for flag, default, caster in cli_options:
        parser.add_argument(flag, default=default, type=caster)

    # An explicit empty list is parsed instead of sys.argv, so only the
    # defaults above take effect.
    args = parser.parse_args(args=[])

    # settings derived from the checkpoint file name
    if 'dinov2' in args.ckpt:
        args.train_resolution = args.test_resolution = 322
    if 'small' in args.ckpt:
        args.dim = 384
    elif 'base' in args.ckpt:
        args.dim = 768

    # the integer between the last '_' and the following '.' of the checkpoint
    # name (8 for the default) -- presumably the ViT patch size
    patch_size = int(args.ckpt.split('_')[-1].split('.')[0])
    args.num_queries = args.train_resolution**2 // patch_size**2

    # the number of gpus for multi-process
    gpu_list = [int(gpu_id) for gpu_id in args.gpu.split(',')]
    ngpus_per_node = len(gpu_list)

    if not args.distributed:
        # single process: only the first listed gpu is used
        main(rank=gpu_list[0], args=args, ngpus_per_node=1)
    else:
        # one spawned process per listed gpu
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        mp.spawn(main, args=(args, ngpus_per_node), nprocs=ngpus_per_node, join=True)



------------------Configurations------------------
NAME_TAG: CAUSE-TR
data_dir: ../
dataset: freiburg
ckpt: checkpoint/dino_vit_base_8.pth
epoch: 5
distributed: False
load_segment: True
load_cluster: False
train_resolution: 320
test_resolution: 320
batch_size: 16
num_workers: 3
gpu: 0
port: 12355
grid: True
num_codebook: 2048
reduced_dim: 90
projection_dim: 2048
dim: 768
num_queries: 1600
-------------------------------------------------
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
[Segment] CAUSE/freiburg/dino_vit_base_8/2048/segment_tr.pth loaded
Modularity CAUSE/freiburg/modularity/dino_vit_base_8/2048/modular.npy loaded
-------------TRAIN EPOCH: 1-------------


[Train] Loss: -0.00=0.00-0.00 ACC: 100.0%: 100%|██████████| 23/23 [01:59<00:00,  5.22s/it]

-------------TEST EPOCH: 1-------------



[TEST] Acc (Linear): 100.0% | [mIoU]: 14.9, [mAP]: 25.0, [Acc]: 59.5, : 100%|██████████| 23/23 [02:37<00:00,  6.83s/it]

-----------------TEST Epoch 0: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth-----------------
-------------TRAIN EPOCH: 2-------------



[Train] Loss: -0.00=0.00-0.00 ACC: 100.0%: 100%|██████████| 23/23 [01:50<00:00,  4.82s/it]

-------------TEST EPOCH: 2-------------



[TEST] Acc (Linear): 100.0% | [mIoU]: 20.5, [mAP]: 33.3, [Acc]: 61.5, : 100%|██████████| 23/23 [02:11<00:00,  5.72s/it]

-----------------TEST Epoch 1: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth-----------------
-------------TRAIN EPOCH: 3-------------



[Train] Loss: -0.00=0.00-0.00 ACC: 100.0%: 100%|██████████| 23/23 [02:02<00:00,  5.35s/it]

-------------TEST EPOCH: 3-------------



[TEST] Acc (Linear): 100.0% | [mIoU]: 26.1, [mAP]: 50.0, [Acc]: 52.3, : 100%|██████████| 23/23 [02:04<00:00,  5.40s/it]

-----------------TEST Epoch 2: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth-----------------
-------------TRAIN EPOCH: 4-------------



[Train] Loss: -0.00=0.00-0.00 ACC: 100.0%: 100%|██████████| 23/23 [02:05<00:00,  5.45s/it]

-------------TEST EPOCH: 4-------------



[TEST] Acc (Linear): 100.0% | [mIoU]: 28.2, [mAP]: 50.0, [Acc]: 56.3, : 100%|██████████| 23/23 [02:20<00:00,  6.10s/it]

-----------------TEST Epoch 3: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth-----------------
-------------TRAIN EPOCH: 5-------------



[Train] Loss: -0.00=0.00-0.00 ACC: 100.0%: 100%|██████████| 23/23 [01:54<00:00,  4.98s/it]

-------------TEST EPOCH: 5-------------



[TEST] Acc (Linear): 100.0% | [mIoU]: 28.1, [mAP]: 50.0, [Acc]: 56.3, : 100%|██████████| 23/23 [01:49<00:00,  4.76s/it]

-----------------TEST Epoch 4: SAVING CHECKPOINT IN CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth-----------------





In [8]:
# test_tr file here 


import argparse

from tqdm import tqdm
from utils.utils import *
from modules.segment_module import transform, untransform
from loader.dataloader import dataloader
from torch.cuda.amp import autocast
from loader.netloader import network_loader, segment_tr_loader, cluster_tr_loader


def test(args, net, segment, cluster, nice, test_loader, cmap):
    """Evaluate cluster predictions refined by CRF and pass them to ``save_all``.

    For every batch the backbone is run on the image and on a copy flipped
    along dim 3; the two segment features are averaged (after flipping the
    second one back), resized to the label resolution, turned into cluster
    predictions, refined with ``do_crf`` and scored with ``nice.eval``.

    Args:
        args: run configuration, forwarded to ``save_all``.
        net: backbone network. Index 0 along dim 1 of its output is dropped
            (presumably the class token -- TODO confirm).
        segment: module exposing ``head_ema``; switched to eval mode here.
        cluster: module exposing ``forward_centroid``.
        nice: metric accumulator exposing ``eval``, ``do_hungarian`` and
            ``reset``.
        test_loader: iterable of batches holding ``"ind"``, ``"img"`` and
            ``"label"``.
        cmap: colour map forwarded to ``save_all``.

    Returns:
        None. Running metrics are shown in the tqdm description.
    """
    segment.eval()

    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)

    # NOTE: the upstream script ran this loop inside a Pool(40) and passed the
    # pool to do_crf; this notebook calls do_crf without a pool because most
    # computers do not have 40 cores.

    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for _, batch in prog_bar:
            # index, image and label
            ind = batch["ind"].cuda()
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            with autocast():
                # backbone features for the image and its dim-3 flipped copy
                feat = net(img)[:, 1:, :]
                feat_flip = net(img.flip(dims=[3]))[:, 1:, :]

            # average both views; the flipped view is flipped back first
            seg_feat = transform(segment.head_ema(feat))
            seg_feat_flip = transform(segment.head_ema(feat_flip))
            seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2)

            # resize the feature to the label resolution
            interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False)

            # cluster preds
            cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), crf=True)

            # crf refinement, then hard labels
            crf_preds = do_crf(img, cluster_preds).argmax(1).cuda()

            # nice evaluation
            _, desc_nice = nice.eval(crf_preds, label)

            # hungarian matching of the predicted ids
            hungarian_preds = nice.do_hungarian(crf_preds)

            # save images
            save_all(args, ind, img, label, cluster_preds.argmax(dim=1), crf_preds, hungarian_preds, cmap, is_tr=True)

            # real-time print
            desc = f'{desc_nice}'
            prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()



def test_without_crf(args, net, segment, cluster, nice, test_loader):
    """Evaluate the linear probe and the cluster probe without CRF refinement.

    Args:
        args: run configuration; ``n_classes`` is read.
        net: backbone network. Index 0 along dim 1 of its output is dropped
            (presumably the class token -- TODO confirm).
        segment: module exposing ``head_ema`` and ``linear``; switched to
            eval mode here.
        cluster: module exposing ``forward_centroid``.
        nice: metric accumulator exposing ``eval`` and ``reset``.
        test_loader: iterable of batches holding ``"img"`` and ``"label"``.

    Returns:
        None. Running metrics are shown in the tqdm description.
    """
    segment.eval()

    total_acc = 0
    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)

    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for idx, batch in prog_bar:
            # image and label (the per-batch colormap / inverse-transform
            # preview that used to be built here was never used)
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            with autocast():
                # intermediate feature
                feat = net(img)[:, 1:, :]
                seg_feat_ema = segment.head_ema(feat)

                # linear probe logits, resized to the label resolution
                linear_logits = segment.linear(seg_feat_ema)
                linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)

                # only labels inside [0, n_classes) count towards accuracy
                flat_label = label.reshape(-1)
                flat_label_mask = (flat_label >= 0) & (flat_label < args.n_classes)

                # segment feature resized to the label resolution
                interp_seg_feat = F.interpolate(transform(seg_feat_ema), label.shape[-2:], mode='bilinear', align_corners=False)

                # cluster probe prediction
                cluster_preds = cluster.forward_centroid(untransform(interp_seg_feat), inference=True)

                # nice evaluation
                _, desc_nice = nice.eval(cluster_preds, label)

            # linear probe accuracy on the valid pixels of this batch
            # (the denominator is zero if a batch has no valid pixel)
            pred_label = linear_logits.argmax(dim=1)
            flat_pred_label = pred_label.reshape(-1)
            acc = (flat_pred_label[flat_label_mask] == flat_label[flat_label_mask]).sum() / flat_label[
                flat_label_mask].numel()
            total_acc += acc.item()

            # real-time print: mean of the per-batch accuracies so far
            desc = f'[TEST] Acc (Linear): {100. * total_acc / (idx + 1):.1f}% | {desc_nice}'
            prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()


def test_linear_without_crf(args, net, segment, nice, test_loader):
    """Evaluate the linear probe with flip averaging and without CRF.

    For every batch the backbone is run on the image and on a copy flipped
    along dim 3; the two segment features are averaged (after flipping the
    second one back), resized to the label resolution and classified by
    ``segment.linear``. The arg-max class is scored with ``nice.eval``.

    Args:
        args: run configuration (not read in this function).
        net: backbone network. Index 0 along dim 1 of its output is dropped
            (presumably the class token -- TODO confirm).
        segment: module exposing ``head_ema`` and ``linear``; switched to
            eval mode here.
        nice: metric accumulator exposing ``eval`` and ``reset``.
        test_loader: iterable of batches holding ``"img"`` and ``"label"``.

    Returns:
        None. Running metrics are shown in the tqdm description.
    """
    segment.eval()

    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)

    # No worker pool is opened here: nothing in this function runs CRF, so the
    # former Pool(40) only spawned processes that were never used.
    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for _, batch in prog_bar:
            # image and label
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            with autocast():
                # backbone features for the image and its dim-3 flipped copy
                feat = net(img)[:, 1:, :]
                feat_flip = net(img.flip(dims=[3]))[:, 1:, :]

            # average both views; the flipped view is flipped back first
            seg_feat = transform(segment.head_ema(feat))
            seg_feat_flip = transform(segment.head_ema(feat_flip))
            seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2)

            # resize the feature to the label resolution
            interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False)

            # linear probe on the resized feature
            linear_logits = segment.linear(untransform(interp_seg_feat))

            # hard class prediction
            cluster_preds = linear_logits.argmax(dim=1)

            # nice evaluation
            _, desc_nice = nice.eval(cluster_preds, label)

            # real-time print
            desc = f'{desc_nice}'
            prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()



def test_linear(args, net, segment, nice, test_loader):
    """Evaluate the linear probe with flip averaging and CRF refinement.

    Same pipeline as the linear probe without CRF, except that the
    log-softmax of the linear logits is refined by ``do_crf`` before the
    arg-max class is scored with ``nice.eval``.

    Args:
        args: run configuration (not read in this function).
        net: backbone network. Index 0 along dim 1 of its output is dropped
            (presumably the class token -- TODO confirm).
        segment: module exposing ``head_ema`` and ``linear``; switched to
            eval mode here.
        nice: metric accumulator exposing ``eval`` and ``reset``.
        test_loader: iterable of batches holding ``"img"`` and ``"label"``.

    Returns:
        None. Running metrics are shown in the tqdm description.
    """
    segment.eval()

    prog_bar = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)

    # do_crf is called as do_crf(img, preds), the same form the `test`
    # function of this notebook uses, so no multiprocessing pool is opened.
    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for _, batch in prog_bar:
            # image and label
            img = batch["img"].cuda()
            label = batch["label"].cuda()

            with autocast():
                # backbone features for the image and its dim-3 flipped copy
                feat = net(img)[:, 1:, :]
                feat_flip = net(img.flip(dims=[3]))[:, 1:, :]

            # average both views; the flipped view is flipped back first
            seg_feat = transform(segment.head_ema(feat))
            seg_feat_flip = transform(segment.head_ema(feat_flip))
            seg_feat = untransform((seg_feat + seg_feat_flip.flip(dims=[3])) / 2)

            # resize the feature to the label resolution
            interp_seg_feat = F.interpolate(transform(seg_feat), label.shape[-2:], mode='bilinear', align_corners=False)

            # linear probe on the resized feature
            linear_logits = segment.linear(untransform(interp_seg_feat))

            # per-class log-probabilities handed to the CRF
            cluster_preds = torch.log_softmax(linear_logits, dim=1)

            # crf refinement, then hard labels
            crf_preds = do_crf(img, cluster_preds).argmax(1).cuda()

            # nice evaluation
            _, desc_nice = nice.eval(crf_preds, label)

            # real-time print
            desc = f'{desc_nice}'
            prog_bar.set_description(desc, refresh=True)

    # evaluation metric reset
    nice.reset()


def main(rank, args):
    """Load the models and the codebook, then run both evaluations.

    Runs ``test_without_crf`` first and ``test`` (with CRF, Hungarian
    matching and image saving) second.

    Args:
        rank: CUDA device index to run on.
        args: parsed run configuration (see the argparse setup in this cell).

    Returns:
        None. Returns early, after a log message, when
        ``pickle_path_and_exist`` reports that the codebook file is missing.
    """
    # bind this process to its GPU and show the configuration
    torch.cuda.set_device(rank)
    print_argparse(args, rank=0)

    # only the test split is needed for evaluation
    _unused_train_loader, test_loader, _ = dataloader(args, False)

    # models
    net = network_loader(args, rank)
    segment = segment_tr_loader(args, rank)
    cluster = cluster_tr_loader(args, rank)

    # evaluation
    nice = NiceTool(args.n_classes)

    # colour map used when saving predictions
    if args.dataset == 'cityscapes':
        cmap = create_cityscapes_colormap()
    else:
        cmap = create_pascal_label_colormap()

    # The codebook comes from an earlier stage (train_mediator.py); without
    # it there is nothing to evaluate.
    path, is_exist = pickle_path_and_exist(args)
    if not is_exist:
        rprint('Train Modularity-based Codebook First', rank)
        return

    # share one frozen codebook tensor between the cluster module and both heads
    cb = torch.from_numpy(np.load(path)).cuda()
    cluster.codebook.data = cb
    cluster.codebook.requires_grad = False
    for head in (segment.head, segment.head_ema):
        head.codebook = cb
    rprint(f'Modularity {path} loaded', rank)

    # param size
    print(f'# of Parameters: {num_param(segment)/10**6:.2f}(M)') 

    # pass 1: linear + cluster probe, no post-processing
    test_without_crf(args, net, segment, cluster, nice, test_loader)

    print('done test_without_crf, starting test')

    # pass 2: post-processing with crf and hungarian matching
    test(args, net, segment, cluster, nice, test_loader, cmap)

    # Optional linear-probe evaluations, disabled by default; both take
    # (args, net, segment, nice, test_loader):
    #   test_linear_without_crf(args, net, segment, nice, test_loader)
    #   test_linear(args, net, segment, nice, test_loader)


if __name__ == "__main__":

    # (flag, default, type) triples, registered in exactly this order
    cli_options = [
        # model parameter
        ('--NAME-TAG', 'CAUSE-TR', str),
        ('--data_dir', '../', str),
        ('--dataset', 'freiburg', str),
        ('--port', '12355', str),
        ('--ckpt', 'checkpoint/dino_vit_base_8.pth', str),
        ('--distributed', False, str2bool),
        ('--load_segment', True, str2bool),
        ('--load_cluster', True, str2bool),
        ('--train_resolution', 320, int),
        ('--test_resolution', 320, int),
        ('--batch_size', 16, int),
        ('--num_workers', 1, int),
        ('--gpu', '0', str),
        ('--num_codebook', 2048, int),
        # model parameter
        ('--reduced_dim', 90, int),
        ('--projection_dim', 2048, int),
    ]

    parser = argparse.ArgumentParser()
    for flag, default, caster in cli_options:
        parser.add_argument(flag, default=default, type=caster)

    # An explicit empty list is parsed instead of sys.argv, so only the
    # defaults above take effect.
    args = parser.parse_args(args=[])

    # settings derived from the checkpoint file name
    if 'dinov2' in args.ckpt:
        args.train_resolution = args.test_resolution = 322
    if 'small' in args.ckpt:
        args.dim = 384
    elif 'base' in args.ckpt:
        args.dim = 768

    # the integer between the last '_' and the following '.' of the checkpoint
    # name (8 for the default) -- presumably the ViT patch size
    patch_size = int(args.ckpt.split('_')[-1].split('.')[0])
    args.num_queries = args.train_resolution**2 // patch_size**2

    # the number of gpus for multi-process
    gpu_list = [int(gpu_id) for gpu_id in args.gpu.split(',')]
    ngpus_per_node = len(gpu_list)

    # evaluation is single-process: only the first listed gpu is used
    main(rank=gpu_list[0], args=args)

------------------Configurations------------------
NAME_TAG: CAUSE-TR
data_dir: ../
dataset: freiburg
port: 12355
ckpt: checkpoint/dino_vit_base_8.pth
distributed: False
load_segment: True
load_cluster: True
train_resolution: 320
test_resolution: 320
batch_size: 16
num_workers: 1
gpu: 0
num_codebook: 2048
reduced_dim: 90
projection_dim: 2048
dim: 768
num_queries: 1600
-------------------------------------------------
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
[Segment] CAUSE/freiburg/dino_vit_base_8/2048/segment_tr.pth loaded
[Cluster] CAUSE/freiburg/dino_vit_base_8/2048/cluster_tr.pth loaded
Modularity CAUSE/freiburg/modularity/dino_vit_base_8/2048/modular.npy loaded
# of Parameters: 9.84(M)


[TEST] Acc (Linear): 100.0% | [mIoU]: 28.1, [mAP]: 50.0, [Acc]: 56.3, : 100%|██████████| 23/23 [02:22<00:00,  6.20s/it]

done test_without_crf, starting test



[mIoU]: 29.4, [mAP]: 50.0, [Acc]: 58.7, : 100%|██████████| 23/23 [08:22<00:00, 21.84s/it]
