In [3]:
import sys; sys.path.append("..")
import warnings; warnings.filterwarnings('ignore')

from core import * 
from data_manipulation import Transform, RandomRotation, Flip, RandomCrop, normalize_imagenet, normalize_mura, center_crop
from utils import save_model, load_model, lr_loss_plot
from architectures import DenseNet121
from train_functions import OptimizerWrapper, TrainingPolicy, FinderPolicy, validate_multilabel, lr_finder, validate_binary, TTA_binary
import json

SEED = 42
R_PIX = 8
IDX = 10 # Emphysema
BATCH_SIZE = 16
EPOCHS = 30
TRANSFORMATIONS = [RandomRotation(arc_width=20), Flip(), RandomCrop(r_pix=R_PIX)]
NORMALIZE = True # ImageNet
FREEZE = True
GRADUAL_UNFREEZING = True
n_samples = [50,100,200,400,600,800,1000]



BASE_PATH = Path('../..')
PATH = BASE_PATH/'data'
CHESTXRAY_FOLDER = PATH/'ChestXRay-250'
CHEXPERT_FOLDER = PATH/'ChesXPert-250'

SAVE_DIRECTORY = Path('./models')

## Load data

In [14]:
# Supervised
chesxray_train_df = pd.read_csv(PATH/"train_df.csv")
chesxray_valid_df = pd.read_csv(PATH/"val_df.csv")
chesxray_test_df = pd.read_csv(PATH/"test_df.csv")

# Unsupervised
chexpert_train_df = pd.read_csv(PATH/"CheXpert-v1.0-small/train.csv")
chexpert_valid_df = pd.read_csv(PATH/"CheXpert-v1.0-small/valid.csv")
chexpert_train_df = chexpert_train_df[chexpert_train_df['Frontal/Lateral']=="Frontal"]
chexpert_valid_df = chexpert_valid_df[chexpert_valid_df['Frontal/Lateral']=="Frontal"]

# Data frame labeled data subsetting

In [16]:
def decode_labels(df_col):
    return np.array(list(map(np.array, df_col.str.split(' ')))).astype(int)

def subset_df(df, amt=None, idx=IDX):
    
    lbls = decode_labels(df.Label)
    
    if amt is None: amt=2*lbls[:,idx].sum()
        
#     df.Label = lbls[:,idx].astype(int)
    pos_idxs = lbls[:,idx].astype(bool)

    neg = df[~pos_idxs].sample(n=amt//2, replace=False)
    pos = df[pos_idxs].sample(n=amt//2, replace=False)

    return pd.concat([neg, pos]).reset_index(drop=True)

# Datasets

In [39]:
class LabeledDataSet(Dataset):
    """
    Basic Images DataSet

    Args:
        dataframe with data: image_file, label
    """

    def __init__(self, df, image_path, idx):
        self.image_files = df["ImageIndex"].values
        self.lables = np.array([obs.split(" ")[idx]
                                for obs in df.Label]).astype(np.float32)
        self.image_path = image_path

    def __getitem__(self, index):
        path = self.image_path / self.image_files[index]
        x = cv2.imread(str(path)).astype(np.float32)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) / 255
        y = self.lables[index]
        return x, y

    def __len__(self):
        return len(self.image_files)
    
    
class UnlabeledDataSet(Dataset):
    """
    Basic Images DataSet

    Args:
        dataframe with data: image_file, label
    """

    def __init__(self, df, image_path, N):
        self.image_files = ['_'.join(p.split('/')[1:]) for p in df["Path"].values]
        self.image_path = image_path
        self.N = N
        self._replace = True if N > len(self.image_files) else False
        
        self.randomize()
        

    def __getitem__(self, index):
        path = self.image_path / self.iter_image_files[index]
        x = cv2.imread(str(path)).astype(np.float32)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) / 255

        return x, None

    def __len__(self):
        return self.N
    
    def randomize(self): self.iter_image_files = np.random.choice(self.image_files, size=self.N, replace=self._replace)

# Transformations

In [10]:
class UnlabeledTransform():
    """ Rotates an image by deg degrees

    Args:

        dataset: A base torch.utils.data.Dataset of images
        transforms: list with all the transformations involving randomnes

        Ex:
            ds_transform = Transform(ds, [random_crop(240, 240), rotate_cv()])

    """

    def __init__(self, dataset, transforms=None, normalize=True, seed=42, r_pix=8):
        self.dataset, self.transforms = dataset, transforms

        if normalize is True: self.normalize = normalize_imagenet
        elif normalize=='MURA': self.normalize = normalize_mura
        else: self.normalize = False

        self.center_crop = partial(center_crop, r_pix=r_pix)

        np.random.seed(seed)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Do transformation when image is called.
        We are assuming the trainingvalidation set is read from a folder of images already
        noramlized and resized to before random-crop and after random-crop sizes respectively.

        """
        data, label = self.dataset[index]
        
        out = np.copy(data)

        if self.transforms:
            for choices, f in list(zip(self.choices, self.transforms)):
                args = {k: v[index] for k, v in choices.items()}
                out = f(out, **args)
        else:
            out=self.center_crop(im=out)
        
        data = self.center_crop(data)

        if self.normalize: 
            out = self.normalize(out)
            data = self.normalize(data)
            
        return np.rollaxis(out, 2), np.rollaxis(data, 2)
    
    def randomize(self): self.dataset.randomize()

    def set_random_choices(self):
        """
        To be called at the begining of every epoch to generate the random numbers
        for all iterations and transformations.
        """
        self.choices = []
        x_shape = self.dataset[0][0].shape
        N = len(self)

        for t in self.transforms:
            self.choices.append(t.set_random_choices(N, x_shape))
  

# Wrapper & DataLoader

In [20]:
class DataBatches:
    '''
    Creates a dataloader using the specificed data frame with the dataset corresponding to "data".
    '''

    def __init__(self, df, transforms, shuffle, img_folder_path, idx=IDX, batch_size=16, num_workers=8,
                 drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='supervised', N=None):

        if problem_type=='supervised':
            self.dataset = Transform(LabeledDataSet(df, image_path=img_folder_path, idx=idx),
                                     transforms=transforms, normalize=normalize, seed=seed, r_pix=r_pix)
        elif problem_type=='unsupervised':
            self.dataset = UnlabeledTransform(UnlabeledDataSet(df, image_path=img_folder_path, N=N),
                                     transforms=transforms, normalize=normalize, seed=seed, r_pix=r_pix)
        self.dataloader = DataLoader(
            self.dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
            shuffle=shuffle, drop_last=drop_last
        )

    def __iter__(self): return ((x.cuda().float(), y.cuda().float()) for (x, y) in self.dataloader)

    def __len__(self): return len(self.dataloader)

    def set_random_choices(self):
        if hasattr(self.dataset, "set_random_choices"): self.dataset.set_random_choices()


In [None]:
# labeled_training = DataBatches(chesxray_train_df_balanced, TRANSFORMATIONS, idx=IDX, shuffle=True, img_folder_path=CHESTXRAY_FOLDER, batch_size=16, num_workers=8,
#                  drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='supervised')

# labeled_training.set_random_choices()
# x,y = next(iter(labeled_training))
# print(x.shape, y.shape)

In [None]:
# unlabeled_training = DataBatches(chexpert_train_df, TRANSFORMATIONS, shuffle=True, img_folder_path=CHEXPERT_FOLDER, batch_size=16, num_workers=8,
#                  drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='unsupervised')

# unlabeled_training.set_random_choices()
# x,y = next(iter(unlabeled_training))
# print(x.shape, y.shape)

# Training

In [42]:
def kl_divergence_with_logits(logit, logit_t):
    p = F.softmax(logit, dim=1)
    log_p = F.log_softmax(logit, dim=1)
    log_q = F.log_softmax(logit_t, dim=1)
    kl = (p * (log_p - log_q)).sum(1).mean()
    return kl

def train(n_epochs, train_dl, unsuper_dl, valid_dl, model, max_lr=.01, wd=0, alpha=1./ 3,
          save_path=None, unfreeze_during_loop:tuple=None):
    
    if unfreeze_during_loop:
        total_iter = n_epochs*len(train_dl)
        first_unfreeze = int(total_iter*unfreeze_during_loop[0])
        second_unfreeze = int(total_iter*unfreeze_during_loop[1])

    best_loss = np.inf
    cnt = 0
    
    policy = TrainingPolicy(n_epochs=n_epochs, dl=train_dl, max_lr=max_lr)
    optimizer = OptimizerWrapper(model, policy, wd=wd, alpha=alpha)

    for epoch in tqdm_notebook(range(n_epochs), ):
        model.train()
        agg_div = 0
        agg_loss = 0
        train_dl.set_random_choices()
        unsuper_dl.set_random_choices()
        for (x,y), (x1,x2) in tqdm_notebook(zip(train_dl, unsuper_dl), leave=False):

            if unfreeze_during_loop:
                if cnt == first_unfreeze: model.unfreeze(1)
                if cnt == second_unfreeze: model.unfreeze(0)

#           first loss
            out = model(x)
            loss = F.binary_cross_entropy_with_logits(input=out.squeeze(), target=y)
            
            
#           second loss
            x = x1
            with torch.no_grad(): logit1 = model(x)
            logit2 = model(x2)
            loss += kl_divergence_with_logits(logit1, logit2)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            batch = y.shape[0]
            agg_loss += batch*loss.item()
            agg_div += batch
            cnt += 1


        val_loss, measure, _ = validate_binary(model, valid_dl)
        print(f'Ep. {epoch+1} - train loss {agg_loss/agg_div:.4f} -  val loss {val_loss:.4f} AUC {measure:.4f}')

        if save_path and val_loss < best_loss:
            save_model(model, save_path)
            best_loss = val_loss

In [22]:
chesxray_valid_df_balanced = subset_df(chesxray_valid_df, amt=None, idx=IDX)
chesxray_test_df_balanced = subset_df(chesxray_test_df, amt=None, idx=IDX)

N = 50
chesxray_train_df_balanced = subset_df(chesxray_train_df, amt=N, idx=IDX)

In [40]:
labeled_training = DataBatches(chesxray_train_df_balanced, TRANSFORMATIONS, idx=IDX, shuffle=True, img_folder_path=CHESTXRAY_FOLDER, batch_size=16, num_workers=8,
                 drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='supervised')

labeled_validation = DataBatches(chesxray_valid_df_balanced, None, idx=IDX, shuffle=False, img_folder_path=CHESTXRAY_FOLDER, batch_size=16, num_workers=8,
                 drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='supervised')

unlabeled_training = DataBatches(chexpert_train_df, TRANSFORMATIONS, shuffle=True, img_folder_path=CHEXPERT_FOLDER, batch_size=16, num_workers=8,
                 drop_last=False, r_pix=8, normalize=True, seed=42, problem_type='unsupervised', N = N)

In [41]:
pretrained = True
freeze = True

dn121 = DenseNet121(1, pretrained=pretrained, freeze=freeze).cuda()

train(10, train_dl=labeled_training, unsuper_dl=unlabeled_training, valid_dl=labeled_validation, model=dn121, max_lr=.01, wd=0, alpha=1./ 3,
          save_path=None, unfreeze_during_loop=(.1,.3))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 1 - train loss 0.7352 -  val loss 0.6716 AUC 0.6383


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 2 - train loss 0.6793 -  val loss 2.3522 AUC 0.5693


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 3 - train loss 1.0804 -  val loss 21.7729 AUC 0.5334


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 4 - train loss 0.7966 -  val loss 438.3789 AUC 0.5119


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 5 - train loss 0.6626 -  val loss 3.1003 AUC 0.5437


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 6 - train loss 0.6038 -  val loss 7.9705 AUC 0.5223


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 7 - train loss 0.5716 -  val loss 0.9217 AUC 0.5812


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 8 - train loss 0.5750 -  val loss 1.1793 AUC 0.6011


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 9 - train loss 0.5751 -  val loss 0.7011 AUC 0.6230


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Ep. 10 - train loss 0.5358 -  val loss 0.7029 AUC 0.6140
