In [1]:
import sys
sys.executable

'/home/rs37890/.conda/envs/UDA/bin/python'

In [2]:
import os
import re
import sys
import time
import copy
import h5py
import random
import logging
import argparse
import numpy as np
import pandas as pd
from PIL import Image

In [3]:
import torch
import torch.nn as nn

In [4]:
from torchvision import *
import torchvision.models

In [5]:
from sklearn.metrics import roc_auc_score

# args

In [6]:
parser = argparse.ArgumentParser()

parser.add_argument('--report-freq', type=int, default=1, help='logging frequency')
parser.add_argument('--tune-mode', type=str, default='fine-tune', choices=['fine-tune', 'feature-extract'], help='tuning mode' )
parser.add_argument('--backbone', type=str, default='resnet50', choices=['resnet50', 'vgg19', 'inception_v3'], 
                    help='backbone architecture' )
parser.add_argument('--cls-type', type=str, default='single', choices=['single', 
                   'double', 'double-bn', 'double-dropout'], help='classifier architecture' )
parser.add_argument('--hidden-dim', type=int, default=512, help='hidden dimension of classifier' )
parser.add_argument('--record-root-dir', type=str, default='./record-data', help='record data root dir' )
parser.add_argument('--exp', type=str, default='default_exp', help='name of experiment' )
parser.add_argument('--batch-size', type=int, default=8, help='batch size' )
parser.add_argument('--num-workers', type=int, default=0, help='number of processes working on cpu.')
parser.add_argument('--num-classes', type=int, default=5, help='number of classes')
parser.add_argument('--num-epochs', type=int, default=20,  help='number of epochs.')
parser.add_argument( '--num-steps', type=int, default=-1, help='number of steps per epoch. '+ '-1 means use entire data' )
parser.add_argument('--learn-rate', type=float, default=1e-3, help='learning rate for gradient descent')
parser.add_argument('--weight-decay', type=float, default=1e-3, help='weight decay for optimization')
parser.add_argument('--resume', action='store_true', help='resume experiment <exp> from last checkpoint' )
parser.add_argument('--input-dir', type=str, default= r'../h5py/', help='data root dir' )
parser.add_argument('--save-name', type=str, default='model.pt', help='saved model name' )

args, unknown =  parser.parse_known_args()

In [7]:
# device
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args.device = DEVICE
args.gpu_ids = [0]

In [8]:
args.num_epochs = 100
args.learn_rate = 1e-6
args.weight_decay = 1e-6 
args.batch_size = 120
args.num_classes = 1
args.seed = 1111

In [9]:
args.root = '/scratch/rs37890/CARC/Explainable-NN-model'
args.sub_root = '/Fold1_Bio-Unet-stage1'
args.checkpoint_root = '/F(CLR) + F(Res) + F(Seg)'
args.shallow_network_folder = '/Mel-detection_effect_of_different_lr/1e-6'

args.csv_h5_dir = args.root + '/Data'

args.record_dir = args.root  + args.sub_root + args.shallow_network_folder + '/record_dir'
args.save_name = 'Resnet50'

# dataset

In [10]:
def get_transform(split):
    
    img_size = ( 512, 512 )
    Normal_transform = transforms.Compose([ transforms.Resize( img_size ),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])],
                                         )

    return Normal_transform

In [11]:
class SkinDataset( torch.utils.data.Dataset ):

    def __init__( self, h5_file_path, transform, split, args):
        self.h5_file_path = h5_file_path
        self.h5_file = None
        self.transform = transform
        self.split = split
        self.img_id_to_h5idx = self.build_img_id_to_h5idx()
        self.num_imgs = self.get_num_imgs() 

        self.args = args

        self.ATTR_TO_INDEX = { 'globules': 0,
                               'milia_like_cyst': 1,
                               'negative_network': 2,
                               'pigment_network': 3,
                               'streaks': 4,
                             }

    def get_num_imgs( self ):
        with h5py.File( self.h5_file_path, 'r' ) as f:
            return len( f['image_ids'] )

    def build_img_id_to_h5idx( self ):
        with h5py.File( self.h5_file_path, 'r' ) as f:
            
            img_ids = f['image_ids']
            img_id_to_h5idx = { img_id : idx for idx, img_id in enumerate( img_ids ) }
            
            return img_id_to_h5idx

    def __len__( self ):
        return self.get_num_imgs()

    def __getitem__( self, idx ):
        # import pdb; pdb.set_trace()
        if not self.h5_file:
            self.h5_file = h5py.File( self.h5_file_path, 'r' )
            
        img_id = self.h5_file['image_ids'][idx]
        img = self.h5_file['images'][idx]
        img = img.transpose([1, 2, 0])

        assert img.shape == (512, 512, 3)

        img = Image.fromarray( np.uint8(img) )
        
        if self.split == 'unsup':
            img1 = self.transform( img )
            img2 = self.transform( img )
            return img_id, img1, img2
        
        else:
            masks = self.h5_file['masks'][idx]
            labels = self.h5_file['labels'][idx].astype(np.float64)
            img = self.transform( img )

            # labels[0] globus
            # labels[1] milia_like_cyst
            # labels[2] negative
            # labels[3] pigment
            # labels[4] streaks
            # labels[5] melanoma
            
            index = 5

            assert masks.shape == (5, 512, 512)
            assert np.expand_dims(labels[index], axis=0).shape == (1,)
            
            return img_id, img, np.expand_dims(labels[index], axis=0), masks

In [12]:
def get_dataloader( args ):
    
    dataloader = {}
    splits = [ 'train', 'val' ]
    
    for split in splits:
        
        h5_file_path = os.path.join( args.csv_h5_dir, f'{split}.h5' )
        transform = get_transform(split)

        bsz = args.batch_size
        
        dataset = SkinDataset( h5_file_path, transform, split, args)
        loader = torch.utils.data.DataLoader( dataset = dataset,
                                              batch_size = bsz,
                                              shuffle = True if split != 'val' else False,
                                              num_workers = args.num_workers,
                                              drop_last = False,
                                            )
        
        dataloader[ split ] = loader
        
    return dataloader

# Load model

In [13]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = torchvision.models.resnet50(pretrained=True)
        self.base_layers = list(self.base_model.children())
        
        self.layer0 = nn.Sequential(*self.base_layers[:3])
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 256, x.H/4, x.W/4)
        self.layer2 = self.base_layers[5]  # size=(N, 512, x.H/8, x.W/8)
        self.layer3 = self.base_layers[6]  # size=(N, 1024, x.H/16, x.W/16)
        self.layer4 = self.base_layers[7]  # size=(N, 2048, x.H/32, x.W/32)
        self.Avgpooling = nn.Sequential(self.base_layers[8])
        
        self.mlp = torch.nn.Sequential( torch.nn.Linear(2048,1), )# only one linear layer on top
        
    def forward(self, X):
        layer0 = self.layer0(X)  # layer0:  torch.Size([1, 64, 256, 256])
        layer1 = self.layer1(layer0) # layer1:  torch.Size([1, 256, 128, 128])
        layer2 = self.layer2(layer1) # layer2:  torch.Size([1, 512, 64, 64])
        layer3 = self.layer3(layer2) # layer3:  torch.Size([1, 1024, 32, 32])
        layer4 = self.layer4(layer3) # layer4:  torch.Size([1, 2048, 16, 16])
        out    = self.Avgpooling(layer4)
        
        out1 = torch.squeeze(out) 
        if X.size(0) == 1: # torch.Size([1, 2048, 1, 1]) => torch.Size([2048])
            out1 = torch.unsqueeze(out1, 0) # torch.Size([1, 2048])
        out = self.mlp(out1)
        return out1, out

# Streaks

In [14]:
net = Model()
streaks_net = torch.nn.DataParallel(net)
streaks_net = streaks_net.module.to( DEVICE )



In [15]:
#Model path
Main_folder = args.root  + args.sub_root + args.checkpoint_root
attribute = "/streaks"
model_name = '/SIMCLR+Resnet50+Unet_model.pt'
best_model_name = '/SIMCLR+Resnet50+Unet_best_model.pt'
streaks_Model_path = Main_folder + attribute + '/record_dir' + best_model_name
streaks_Model_path

'/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/F(CLR) + F(Res) + F(Seg)/streaks/record_dir/SIMCLR+Resnet50+Unet_best_model.pt'

In [16]:
checkpoint = torch.load(streaks_Model_path)
streaks_net.load_state_dict(checkpoint)

<All keys matched successfully>

# Pigment

In [17]:
net = Model()
pigment_net = torch.nn.DataParallel(net)
pigment_net = pigment_net.to( DEVICE )
pigment_net = pigment_net.module

In [18]:
#Model path
Main_folder = args.root  + args.sub_root + args.checkpoint_root
attribute = "/pigment"
model_name = '/SIMCLR+Resnet50+Unet_model.pt'
best_model_name = '/SIMCLR+Resnet50+Unet_best_model.pt'
pigment_Model_path = Main_folder + attribute + '/record_dir' + best_model_name
pigment_Model_path

'/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/F(CLR) + F(Res) + F(Seg)/pigment/record_dir/SIMCLR+Resnet50+Unet_best_model.pt'

In [19]:
checkpoint = torch.load(pigment_Model_path)
pigment_net.load_state_dict(checkpoint)

<All keys matched successfully>

# Negative

In [20]:
net = Model()
negative_net = torch.nn.DataParallel(net)
negative_net = negative_net.to( DEVICE )
negative_net = negative_net.module

In [21]:
#Model path
Main_folder = args.root  + args.sub_root + args.checkpoint_root
attribute = "/negative"
model_name = '/SIMCLR+Resnet50+Unet_model.pt'
best_model_name = '/SIMCLR+Resnet50+Unet_best_model.pt'
negative_Model_path = Main_folder + attribute + '/record_dir' + best_model_name
negative_Model_path

'/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/F(CLR) + F(Res) + F(Seg)/negative/record_dir/SIMCLR+Resnet50+Unet_best_model.pt'

In [22]:
checkpoint = torch.load(negative_Model_path)
negative_net.load_state_dict(checkpoint)

<All keys matched successfully>

# Milia

In [23]:
net = Model()
milia_net = torch.nn.DataParallel(net)
milia_net = milia_net.to( DEVICE )
milia_net = milia_net.module

In [24]:
#Model path
Main_folder = args.root  + args.sub_root + args.checkpoint_root
attribute = "/milia_like_cyst"
model_name = '/SIMCLR+Resnet50+Unet_model.pt'
best_model_name = '/SIMCLR+Resnet50+Unet_best_model.pt'
milia_like_cyst_Model_path = Main_folder + attribute + '/record_dir' + best_model_name
milia_like_cyst_Model_path

'/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/F(CLR) + F(Res) + F(Seg)/milia_like_cyst/record_dir/SIMCLR+Resnet50+Unet_best_model.pt'

In [25]:
checkpoint = torch.load(milia_like_cyst_Model_path)
milia_net.load_state_dict(checkpoint)

<All keys matched successfully>

# Globules

In [26]:
net = Model()
globules_net = torch.nn.DataParallel(net)
globules_net = globules_net.to( DEVICE )
globules_net = globules_net.module

In [27]:
#Model path
Main_folder = args.root  + args.sub_root + args.checkpoint_root
attribute = "/globules"
model_name = '/SIMCLR+Resnet50+Unet_model.pt'
best_model_name = '/SIMCLR+Resnet50+Unet_best_model.pt'
globules_Model_path = Main_folder + attribute + '/record_dir' + best_model_name
globules_Model_path

'/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/F(CLR) + F(Res) + F(Seg)/globules/record_dir/SIMCLR+Resnet50+Unet_best_model.pt'

In [28]:
checkpoint = torch.load(globules_Model_path)
globules_net.load_state_dict(checkpoint)

<All keys matched successfully>

# softmax

In [29]:
class LogisticRegression(torch.nn.Module): 
    def __init__(self, input_dim = 10240, output_dim = 1): 
        super(LogisticRegression, self).__init__() 
        #self.linear = torch.nn.Linear(input_dim, output_dim) 
        self.linear1 = torch.nn.Linear(input_dim, 2048) 
        self.linear2 = torch.nn.Linear(2048, 512) 
        self.linear3 = torch.nn.Linear(512, output_dim) 
        
    def forward(self, x): 
        #outputs = torch.sigmoid(self.linear(x)) 
        output1 = self.linear1(x)
        output2 = self.linear2(output1)
        output3 = self.linear3(output2)
        
        return nn.Sigmoid()(output3).type(torch.float64)

# Trainer

In [30]:
class AverageMeter:
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [31]:
class Trainer:
    def __init__(self, 
                 model, 
                 globules_net , milia_net, negative_net, pigment_net, streaks_net,
                 dataloader, 
                 args,
                ):
        
        self.model = model.to(args.device)
        self.dataloader = dataloader
        self.args = args
        self.best_loss = np.inf
        self.best_epoch = 0
        self.best_model = None
        self.current_epoch = 0

        self.globules_net = globules_net.to(args.device)
        self.milia_net = milia_net.to(args.device)
        self.negative_net = negative_net.to(args.device)
        self.pigment_net = pigment_net.to(args.device)
        self.streaks_net = streaks_net.to(args.device)


        self.globules_net.eval()
        self.milia_net.eval()
        self.negative_net.eval()
        self.pigment_net.eval()
        self.streaks_net.eval()
        

        self.criterion = nn.BCELoss()
        self.optimizer = torch.optim.Adam( params = self.model.parameters(),
                                           lr = self.args.learn_rate,
                                           weight_decay = self.args.weight_decay,
                                         )

        self.logger = self.setup_logger()
        self.setup_experiment()
        self.set_random_seeds()

    def setup_logger(self):
        log_file = 'log.txt'
        if os.path.exists(log_file):
            os.remove(log_file)  # Remove the existing log file if it exists
        
        logger = logging.getLogger('Trainer')
        logger.setLevel(logging.INFO)
    
        # File Handler for log files
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.INFO)
        file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        fh.setFormatter(file_formatter)
        logger.addHandler(fh)
    
        # Stream Handler for console output
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        ch.setFormatter(console_formatter)
        logger.addHandler(ch)
    
        return logger

    def setup_experiment(self):
        self.name = self.args.exp
        self.exp_dir = os.path.join(self.args.record_root_dir, self.name)
        self.logger.info(f'seed is: {self.args.seed}')
        self.logger.info(f'args: {self.args}')

    def set_random_seeds(self):
        seed = random.randint(0, 100000)
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        self.args.seed = seed

    def get_accuracy(self, pred, labels):
        predicted_labels = (pred > 0.5).int()
        correct_predictions = (predicted_labels == labels).sum().item()
        total_predictions = labels.size(0)
        accuracy = correct_predictions / total_predictions
        return accuracy

    def concate_5120(self, imgs):
        # This function should include model evaluation calls
        # Assuming these models are attributes of the class
        with torch.no_grad():
            pred1, _ = self.globules_net(imgs)
            pred2, _ = self.negative_net(imgs)
            pred3, _ = self.milia_net(imgs)
            pred4, _ = self.pigment_net(imgs)
            pred5, _ = self.streaks_net(imgs)
            return torch.cat((pred1, pred2, pred3, pred4, pred5), 1)

    def train_or_val(self, mode='train'):
        
        assert mode in ['train', 'val'], "Mode must be 'train' or 'val'"
        dataloader = self.dataloader[mode]
        num_steps = len(dataloader)
        self.model.train() if mode == 'train' else self.model.eval()

        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        data_time_meter = AverageMeter()
        batch_time_meter = AverageMeter()
        all_preds = []
        all_labels = []

        start = time.time()
        for batch_id, (img_ids, imgs, labels, masks) in enumerate(dataloader):
            data_time = time.time() - start
            if batch_id >= num_steps: break

            imgs, labels = imgs.to(self.args.device), labels.to(self.args.device)
            
            if mode == 'train':
                self.optimizer.zero_grad()

            concate_input = self.concate_5120(imgs)
            pred = self.model(concate_input)

            
            loss = self.criterion(pred, labels)
            if mode == 'train':
                loss.backward()
                self.optimizer.step()

            loss_meter.update(loss.item(), len(imgs))
            acc = self.get_accuracy( pred.detach().cpu(), labels.cpu() )
            acc_meter.update(acc, len(imgs))
            batch_time = time.time() - start

            data_time_meter.update(data_time)
            batch_time_meter.update(batch_time)
            all_preds.append(pred.detach().cpu())
            all_labels.append(labels.cpu())

            # if (batch_id+1) % self.args.report_freq == 0:
            self.log_epoch_stats(mode.upper(), batch_id, num_steps, loss_meter.avg, acc_meter.avg, 0, data_time_meter.avg, batch_time_meter.avg)

            start = time.time()

        all_labels = torch.cat(all_labels, dim=0)
        all_preds = torch.cat(all_preds, dim=0)
        self.auc_scores = roc_auc_score( np.array(all_labels), np.array(all_preds), average=None)
        self.log_auc_scores(self.auc_scores)

        return loss_meter.avg, acc_meter.avg, self.auc_scores.mean()

    def log_epoch_stats(self, phase, batch_id, num_steps, avg_loss, avg_acc, mean_auc, avg_data_time, avg_batch_time):
        self.logger.info(f'{phase} Epoch: {self.current_epoch} Step: {batch_id}/{num_steps} Loss: {avg_loss:.4f} Accuracy: {avg_acc:.4f} AUC: {mean_auc:.4f} Data Time: {avg_data_time:.4f} Batch Time: {avg_batch_time:.4f}')

    def log_auc_scores(self, auc_scores):
        if isinstance(auc_scores, np.ndarray):
            for i, score in enumerate(auc_scores):
                self.logger.info(f'AUC Score {i}: {score:.4f}')
        else:
            self.logger.info(f'AUC Score: {auc_scores:.4f}')

    def save_model(self, save_path):
        torch.save(self.model.state_dict(), save_path)

    def train(self):
        for epoch in range(self.current_epoch, self.args.num_epochs):
            self.current_epoch = epoch
            self.logger.info(f'Starting Epoch: {epoch + 1}')
            train_loss, train_acc, train_auc = self.train_or_val('train')
            val_loss, val_acc, val_auc = self.train_or_val('val')

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.best_model = copy.deepcopy(self.model)
                self.best_epoch = epoch + 1
                self.logger.info(f'New best model at epoch: {epoch + 1}')

                save_path = os.path.join( self.args.record_dir, self.args.save_name + '_best.pt')
                self.save_model(save_path)

            save_path = os.path.join( self.args.record_dir, self.args.save_name + '.pt')
            self.save_model(save_path)

In [32]:
dataloaders = get_dataloader( args )

In [33]:
trainer = Trainer(LogisticRegression(), 
                  globules_net , milia_net, negative_net, pigment_net, streaks_net,
                  dataloaders, 
                  args,
                 )

2024-06-25 19:06:50,674 - INFO - seed is: 1111
2024-06-25 19:06:50,677 - INFO - args: Namespace(report_freq=1, tune_mode='fine-tune', backbone='resnet50', cls_type='single', hidden_dim=512, record_root_dir='./record-data', exp='default_exp', batch_size=120, num_workers=0, num_classes=1, num_epochs=100, num_steps=-1, learn_rate=1e-06, weight_decay=1e-06, resume=False, input_dir='../h5py/', save_name='Resnet50', device='cuda:0', gpu_ids=[0], seed=1111, root='/scratch/rs37890/CARC/Explainable-NN-model', sub_root='/Fold1_Bio-Unet-stage1', checkpoint_root='/F(CLR) + F(Res) + F(Seg)', shallow_network_folder='/Mel-detection_effect_of_different_lr/1e-6', csv_h5_dir='/scratch/rs37890/CARC/Explainable-NN-model/Data', record_dir='/scratch/rs37890/CARC/Explainable-NN-model/Fold1_Bio-Unet-stage1/Mel-detection_effect_of_different_lr/1e-6/record_dir')


In [None]:
trainer.train()

2024-06-25 19:06:50,685 - INFO - Starting Epoch: 1
  return F.conv2d(input, weight, bias, self.stride,
2024-06-25 19:06:54,352 - INFO - TRAIN Epoch: 0 Step: 0/18 Loss: 0.7172 Accuracy: 0.2333 AUC: 0.0000 Data Time: 1.7792 Batch Time: 3.6651
2024-06-25 19:06:56,263 - INFO - TRAIN Epoch: 0 Step: 1/18 Loss: 0.7087 Accuracy: 0.3042 AUC: 0.0000 Data Time: 1.3422 Batch Time: 2.7869
2024-06-25 19:06:58,647 - INFO - TRAIN Epoch: 0 Step: 2/18 Loss: 0.7009 Accuracy: 0.4306 AUC: 0.0000 Data Time: 1.2373 Batch Time: 2.6523
2024-06-25 19:07:01,357 - INFO - TRAIN Epoch: 0 Step: 3/18 Loss: 0.6935 Accuracy: 0.5167 AUC: 0.0000 Data Time: 1.3651 Batch Time: 2.6658
2024-06-25 19:07:03,539 - INFO - TRAIN Epoch: 0 Step: 4/18 Loss: 0.6850 Accuracy: 0.5833 AUC: 0.0000 Data Time: 1.2981 Batch Time: 2.5688
2024-06-25 19:07:05,970 - INFO - TRAIN Epoch: 0 Step: 5/18 Loss: 0.6779 Accuracy: 0.6208 AUC: 0.0000 Data Time: 1.2550 Batch Time: 2.5456
2024-06-25 19:07:08,653 - INFO - TRAIN Epoch: 0 Step: 6/18 Loss: 0.67

In [None]:
trainer.auc_scores