# Play with AViT

Using AViT (ViT-B based) as an example, load the trained model weights and evaluate the model on fold 4 of the k-fold split of the ISIC 2018 dataset.

In [17]:
import torch
import yaml
from Utils.pieces import DotDict
from os.path import join
from Datasets.create_dataset import Dataset_wrap_csv
import medpy.metric.binary as metrics
from tqdm import tqdm

In [20]:
# Experiment directory holding the saved config (exp_config.yml) and checkpoint (best.pth).
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
EXP_PATH = '/ubc/ece/home/ra/grads/siyi/Research/skin_lesion_segmentation/results/ISICW/K_results/K3_ISIC/ca_R34_v4_isic2018_ViTSeg_CNNprompt_adapt_20230508_1840'
dataset_name = 'isic2018'
config_yml = join(EXP_PATH, 'exp_config.yml')
ckpt_path = join(EXP_PATH, 'best.pth')
# Read the experiment config inside a context manager so the file handle is closed.
with open(config_yml) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# Wrap the nested dict for attribute-style access (config.data.img_size, ...).
config = DotDict(config)
# Point data loading at the local dataset root, overriding the saved value.
config.data.data_folder = '/bigdata/siyiplace/data/skin_lesion'

In [21]:
# Build the dataset splits for `dataset_name` from the saved config.
datas = Dataset_wrap_csv(
    k_fold=config.data.k_fold,
    use_old_split=True,
    img_size=config.data.img_size,
    dataset_name=dataset_name,
    split_ratio=config.data.split_ratio,
    train_aug=config.data.train_aug,
    data_folder=config.data.data_folder,
)
train_data = datas['train']
# NOTE(review): no separate validation split is used — validation data is
# taken from datas['test'], the same source as the test data.
val_data = datas['test']
test_data = datas['test']


def _make_loader(dataset, loader_cfg, *, train):
    """Wrap `dataset` in a DataLoader configured from `loader_cfg`.

    `loader_cfg` supplies `batch_size` and `num_workers`. Training loaders
    shuffle and drop the last incomplete batch; evaluation loaders keep the
    sample order and every sample.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=loader_cfg.batch_size,
        shuffle=train,
        num_workers=loader_cfg.num_workers,
        pin_memory=True,
        drop_last=train,
    )


train_loader = _make_loader(train_data, config.train, train=True)
val_loader = _make_loader(val_data, config.test, train=False)
test_loader = _make_loader(test_data, config.test, train=False)

/bigdata/siyiplace/data/skin_lesion/isic2018/
isic2018 has 2594 samples, 2076 are used to train, 518 are used to test. 
 5 Folder -- Use 4


In [13]:
if config.model == 'ViTSeg_CNNprompt_adapt':
    from Models.Transformer.ViT_adapters import ViTSeg_CNNprompt_adapt
    model = ViTSeg_CNNprompt_adapt(pretrained=False, pretrained_vit_name=config.vit.name,
    pretrained_folder=config.pretrained_folder,img_size=config.data.img_size, patch_size=config.vit.patch_size,
    embed_dim=config.vit.embed_dim, depth=config.vit.depth, num_heads=config.vit.num_heads, 
    mlp_ratio=config.vit.mlp_ratio, drop_rate=config.vit.dropout_rate, 
    attn_drop_rate=config.vit.attention_dropout_rate, drop_path_rate=0.2, 
    debug=config.debug, adapt_method=config.model_adapt.adapt_method, num_domains=1)

In [14]:
# Load the trained weights. map_location='cpu' lets a checkpoint containing
# CUDA tensors load on a CPU-only machine; the model is moved to the GPU
# afterwards when one is available.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model = model.cuda()

In [22]:
# Evaluate on the test loader: threshold sigmoid outputs at 0.5 and accumulate
# batch-size-weighted Dice and IoU (medpy), then report their averages.
model.eval()
# Run on whichever device the model is on (CPU if no GPU was found when the
# checkpoint was loaded) instead of unconditionally calling .cuda().
device = next(model.parameters()).device
# Domain label passed to the model (the model was built with num_domains=1).
d = '0'
dice_test_sum, iou_test_sum, num_test = 0, 0, 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        img = batch['image'].to(device).float()
        # Labels are only consumed as NumPy arrays, so keep them on the CPU.
        label = batch['label'].float().numpy()
        batch_len = img.shape[0]
        output = torch.sigmoid(model(img, d=d)['seg'])
        # Binarize the predicted probabilities at 0.5.
        output = output.cpu().numpy() > 0.5
        # NOTE(review): dc/jc receive the whole batch array, so each call presumably
        # yields one score pooled over all images in the batch, weighted here by
        # batch size — not the mean of per-image scores. Confirm this is intended.
        dice_test_sum += metrics.dc(output, label) * batch_len
        iou_test_sum += metrics.jc(output, label) * batch_len
        num_test += batch_len

if num_test == 0:
    raise RuntimeError('test_loader yielded no samples; cannot compute metrics')
print(f'{dataset_name}, Dice: {dice_test_sum/num_test}, IOU: {iou_test_sum/num_test}')

100%|██████████| 130/130 [00:20<00:00,  6.47it/s]

isic2018, Dice: 0.9200534490090431, IOU: 0.8558168345745101



