In [1]:
import os
import pickle
import torch
import sys

from PIL import Image
from skimage.io import imread
from skimage.transform import resize
from matplotlib import pyplot as plt
from tabulate import tabulate

In [2]:
# Mount Google Drive so the project folder is reachable from the Colab runtime
from google.colab import drive
drive.mount('/content/drive')

# Remember the starting directory, make the parent directory and the project
# folder importable, then switch the working directory to the entry that was
# appended last (the project folder).
original_path = os.getcwd()
for extra_path in (os.path.join('.', '..'),
                   '/content/drive/My Drive/Deep_Learning_Project12/'):
  sys.path.append(extra_path)
os.chdir(sys.path[-1])

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Import Data and Wrangling

In [3]:
import numpy as np
import pandas as pd

data_files = os.listdir("Files")

# --- labels ------------------------------------------------------------------
labels = pd.read_csv("Files/dermx_labels.csv")
labels["image_path"] = [os.path.join(os.getcwd(),"Files", "images", f"{x}.jpeg") for x in labels["image_id"]]
labels.drop(columns = "Unnamed: 0", inplace = True)

# dropna() returns a new frame, so the result has to be assigned for the rows
# with missing values to actually be removed.
# NOTE(review): this changes the row set if the CSV contains missing values;
# confirm any previously saved split indices still line up.
labels = labels.dropna().reset_index(drop = True)
# dtype=int keeps the dummy columns 0/1 integers on every pandas version
# (pandas >= 2 would otherwise produce booleans).
labels = pd.get_dummies(labels, columns = ["area"], dtype = int)
labels["open_comedo"] = (labels["open_comedo"] > 0).astype(int)

features_target = pd.read_csv("Files/diseases_characteristics.csv")
features_target.rename(columns={"Unnamed: 0":"disease"},inplace=True)

# create one-hot encoding of the diagnosis; stored per row as a list in "ts"
one_hot = pd.get_dummies(labels["diagnosis"], dtype = int)
one_hot_encoding = [list(x) for x in one_hot.values]

labels["ts"] = one_hot_encoding

# get features as multi hot: columns 2-8 plus the four columns at positions
# 10-13 (position 9 is skipped)
features_touse = list(labels.columns[list(range(2,9)) + [10,11,12,13]])
labels["features"] = labels.loc[:, features_touse].values.tolist()

# map every distinct feature combination to an integer id (order of first appearance)
features_map = {
  str(feat): idx
  for idx, feat in enumerate(labels["features"].apply(tuple).unique())
}

labels["features_label"] = labels["features"].apply(tuple).apply(str).map(features_map)

# --- domain knowledge --------------------------------------------------------
domain = pd.read_csv("Files/diseases_characteristics.csv")
domain.rename(columns={"Unnamed: 0":"diagnosis"},inplace=True)
domain = pd.get_dummies(domain, columns = ["area"], dtype = int)
same_sort = ["diagnosis"] + features_touse
# .copy() so the column assignments below write to an independent frame rather
# than to a slice of the frame read from disk
domain = domain[same_sort].copy()  # same sorting

domain_one_hot = pd.get_dummies(domain["diagnosis"], dtype = int)

domain_one_hot_encoding = [list(x) for x in domain_one_hot.values]
domain["ts"] = domain_one_hot_encoding
# the feature columns sit directly after "diagnosis", in the order of features_touse
feature_cols = domain.columns[1:1 + len(features_touse)]
domain["features"] = domain.loc[:,feature_cols].values.tolist()

# add domain features (domain knowledge) to dataframe: one dictionary lookup
# per row; the first row of `domain` wins if a diagnosis is listed twice
domain_lookup = {}
for diagnosis, feats in zip(domain["diagnosis"], domain["features"]):
  domain_lookup.setdefault(diagnosis, feats)

tf = [domain_lookup[disease] for disease in labels["diagnosis"]]
labels["domain_features"] = tf 

domain = domain.sort_values(by="diagnosis").reset_index(drop=True)

data = labels.copy()


In [4]:
def add_domain(df: pd.DataFrame):
  """Attach the domain-knowledge feature vector of each row's diagnosis to `df`.

  The disease characteristics are read from "Files/diseases_characteristics.csv",
  the "area" column is one-hot encoded and the columns are arranged in the order
  taken from the module-level `labels` frame. The frame `df` is modified in
  place (a "domain_features" column is written) and also returned.

  Raises IndexError if a diagnosis in `df` has no row in the characteristics file.
  """
  knowledge = pd.read_csv("Files/diseases_characteristics.csv").rename(
      columns={"Unnamed: 0": "diagnosis"})
  knowledge = pd.get_dummies(knowledge, columns=["area"])

  # same sorting as the global `labels` frame: diagnosis first, then the features
  ordered_columns = list(labels.columns[list(range(1, 9)) + [10, 11, 12, 13]])
  knowledge = knowledge[ordered_columns]

  diagnosis_dummies = pd.get_dummies(knowledge["diagnosis"])
  knowledge["ts"] = [list(one_hot_row) for one_hot_row in diagnosis_dummies.values]
  knowledge["features"] = knowledge.loc[:, knowledge.columns[1:12]].values.tolist()

  def lookup(disease):
    # feature list of the first characteristics row matching this diagnosis
    return knowledge.loc[knowledge.diagnosis == disease].features.tolist()[0]

  # add domain features (domain knowledge) to dataframe
  df["domain_features"] = [lookup(disease) for disease in df["diagnosis"]]
  return df

# Some Useful Functions

In [5]:
#@title 
from HelperFunctions.project_utils import Tracker
from sklearn.utils import class_weight
import ast

def add_no_match(df: pd.DataFrame):
  """Extend `df` with "no_match" rows for every non-matching domain vector.

  For each row and each distinct `domain_features` vector found in `df` that
  differs from the row's own vector, a copy of the row is added with
  `diagnosis` set to "no_match" and `domain_features` set to that other vector.
  The "ts" column is then rebuilt as the one-hot encoding of the extended
  `diagnosis` column (classes in sorted order).

  Args:
    df: frame with at least "diagnosis" and "domain_features" columns, where
      "domain_features" holds list-like values.

  Returns:
    A new frame (original rows first, generated rows after) with a fresh
    0..n-1 index. `df` itself is not modified.
  """
  unique_data = [list(x) for x in set(tuple(x) for x in df.domain_features)]

  records = []
  for _, row in df.iterrows():
    for candidate in unique_data:
      if row["domain_features"] == candidate:
        continue  # the row's true domain vector keeps its real diagnosis
      record = row.to_dict()
      record["diagnosis"] = "no_match"
      record["domain_features"] = candidate
      records.append(record)

  # DataFrame.append was removed in pandas 2.0; build the extra rows once and
  # concatenate instead.
  frames = [df]
  if records:
    frames.append(pd.DataFrame(records, columns=df.columns))
  updated_df = pd.concat(frames, ignore_index=True)

  # Update targets "ts"; dtype=int keeps them 0/1 integers on every pandas version
  new_dummies = pd.get_dummies(updated_df["diagnosis"], dtype=int)
  updated_df["ts"] = [list(x) for x in new_dummies.values]

  return updated_df

def unique_lists(data: list):
  """Return the distinct inner sequences of `data`, each as a list.

  Duplicates are removed through a set of tuples, so the order of the result
  is not guaranteed.
  """
  distinct = {tuple(item) for item in data}
  return list(map(list, distinct))

def map_domain_knowledge(df: pd.DataFrame):
  """Map every diagnosis in `df` to its domain feature vector.

  Args:
    df: frame with "diagnosis" and "domain_features" columns.

  Returns:
    dict keyed by diagnosis (in order of first appearance) whose value is the
    `domain_features` entry of the first row carrying that diagnosis.
  """
  # Only `df` is consulted: the lookup must not depend on any module-level frame.
  knowledge = {}
  for diagnosis, feats in zip(df["diagnosis"], df["domain_features"]):
    knowledge.setdefault(diagnosis, feats)  # keep the first occurrence
  return knowledge

def plt_tracker(tracker: Tracker, num_epoch):
    """Plot training/validation loss and accuracy curves side by side.

    Reads `train_iter`, `val_iter`, `train_loss`, `val_loss`, `train_acc` and
    `val_acc` from `tracker`; the x axis gets a tick every 5 epochs from 0 up
    to `num_epoch`. The figure is shown with `plt.show()`.
    """
    # (panel title / y label, training series, training legend, validation series, validation legend)
    panels = (
        ("Loss", tracker.train_loss, 'Training loss',
         tracker.val_loss, 'Validation loss'),
        ("Accuracy", tracker.train_acc, 'Training accuracy',
         tracker.val_acc, 'Validation accuracy'),
    )
    epoch_ticks = range(0, num_epoch + 1, 5)

    plt.figure(figsize=(14,8))
    for position, (name, train_values, train_label, val_values, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(tracker.train_iter, train_values, label=train_label)
        plt.plot(tracker.val_iter, val_values, label=val_label)
        plt.title(name)
        plt.ylabel(name)
        plt.xlabel("Epoch")
        plt.xticks(epoch_ticks)
        plt.legend()
        plt.grid()

    plt.tight_layout()
    plt.show()


def calc_multiclass_weights(df: pd.DataFrame, device):
  """Compute 'balanced' class weights for the diagnoses in `df`.

  Args:
    df: frame with a "diagnosis" column.
    device: torch device the resulting tensor is moved to.

  Returns:
    1-D float tensor with one weight per distinct diagnosis, ordered by the
    sorted diagnosis values.
  """
  y = df["diagnosis"].to_numpy()
  # np.unique returns the sorted distinct classes as an ndarray, which is the
  # type scikit-learn documents for `classes` (a plain list can be rejected by
  # its parameter validation).
  classes = np.unique(y)
  csw = class_weight.compute_class_weight('balanced', classes = classes, y = y)
  class_weights = torch.tensor(csw,dtype=torch.float).to(device)

  return class_weights

def feature_intersect(domain, features):
    """Return a 0/1 integer vector marking positions equal to 1 in both inputs.

    The result has the length of `domain`; a position is 1 only where both
    `domain` and `features` hold the value 1.
    """
    domain_arr, feature_arr = np.asarray(domain), np.asarray(features)

    # positions that are switched on in each vector
    on_in_domain = np.nonzero(domain_arr == 1)
    on_in_features = np.nonzero(feature_arr == 1)
    shared = np.intersect1d(on_in_domain, on_in_features)

    result = np.zeros(len(domain_arr), dtype=int)
    result[shared] = 1
    return result

def read_splits(path):
  """Read a CSV of splits, parsing the columns at positions 1 and 2 as Python literals."""
  literal_columns = dict.fromkeys((1, 2), ast.literal_eval)
  return pd.read_csv(path, converters=literal_columns)


# Define Dataset Class for features


In [6]:
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

class DomainDatset(Dataset):
  """Dataset of (image features + domain features) vectors with one-hot targets.

  Each item is the element-wise sum of a row's "domain_features" and
  "features" lists as a float tensor, paired with the row's "ts" target as a
  long tensor.

  Args:
    data: frame with "features_label", "features", "domain_features" and "ts"
      columns.
    augment: when True, every "features_label" group is up-sampled (random
      draws with replacement via `np.random.choice`) to the size of the
      largest group.
  """

  def __init__(self, data: pd.DataFrame, augment=True):

    dictator = "features_label"
    if augment:
      data = self._upsample(data, dictator)

    self.dataframe = data

    self.data_input = data["features"].reset_index(drop=True)
    self.domain_input = data["domain_features"].reset_index(drop=True)
    self.target = data["ts"].reset_index(drop=True)

  @staticmethod
  def _upsample(data, column):
    """Return `data` plus randomly repeated rows so all `column` groups are equally large."""
    groups, counts = np.unique(data[column], return_counts=True)
    largest = counts.max()

    extras = []
    for value, count in zip(groups, counts):
      missing = largest - count
      positions = np.where(data[column] == value)[0]
      picked = np.random.choice(positions, size=missing)
      if missing:
        extras.append(data.iloc[picked])

    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    return pd.concat([data, *extras], ignore_index=True)

  def __len__(self):
    return (len(self.data_input))

  def __getitem__(self, i):

    target = self.target[i]

    # the network input is the element-wise sum of both feature vectors
    combined = np.array(self.domain_input[i]) + np.array(self.data_input[i])

    return torch.tensor(combined, dtype=torch.float), torch.tensor(target, dtype=torch.long)

# Model `DomainNet` for learning Domain Knowledge

In [7]:
# create the MTL network
from torch import nn
from torch import optim
import torchvision.models as models


class DomainNet(nn.Module):
    """Small fully connected classifier: Linear -> ReLU -> Dropout(0.5) -> Linear.

    Args:
        num_classes: size of the output layer.
        num_features: size of the input vector.
        num_hidden: width of the hidden layer.
    """

    def __init__(self, num_classes, num_features, num_hidden = 256):
        super().__init__()

        self.num_classes, self.num_features = num_classes, num_features

        # input -> hidden projection followed by the non-linearity
        hidden = nn.Linear(in_features=num_features, out_features=num_hidden)
        self.layer_1 = nn.Sequential(hidden, nn.ReLU())

        # dropout-regularised hidden -> class scores
        classifier = nn.Linear(in_features=num_hidden, out_features=num_classes)
        self.layer_2 = nn.Sequential(nn.Dropout(0.5), classifier)

    def forward(self, x):
        """Return the raw class scores (no softmax is applied) for `x`."""
        return self.layer_2(self.layer_1(x))
    

# Dataset for MTL Net


In [8]:
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

class NaturalImageDataset(Dataset):
  """Image dataset yielding (image tensor, one-hot label, multi-hot features).

  Args:
    data: frame with "image_path", "ts" and "features" columns (plus the
      `dictator` column when `augment` is True).
    augment: when True, the `dictator` groups are up-sampled to the size of the
      largest group (random draws with replacement) and random flips, colour
      jitter and affine transforms are added to the image pipeline.
    load_img: when True, all images are opened and transformed once in the
      constructor and kept in memory; the random augmentations are therefore
      drawn once per row, not once per epoch. When False, images are opened
      and transformed on access instead.
    dictator: name of the column whose groups are balanced by the up-sampling.
  """

  def __init__(self, data, augment = False, load_img=True, dictator="features_label"):

    # upsample if augment; `dictator` is the variable we use to upsample to match
    if augment:
      data = self._upsample(data, dictator)

    self.dataframe = data
    self.imgage_path = data["image_path"].values
    self.labels = data["ts"].values
    self.features = data["features"].values

    # transform image: both pipelines share the resize/crop/tensor front and
    # the normalisation at the end; augmentation steps are inserted in between
    to_tensor_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    augment_steps = [
        transforms.RandomHorizontalFlip(p = 0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(brightness = 0.1, contrast = 0.1),
        transforms.RandomAffine(degrees = 50, translate = (0.1, 0.1), scale = (0.9, 1.1)),
    ]
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    self.transform = transforms.Compose(
        to_tensor_steps + (augment_steps if augment else []) + [normalize])

    # None means "not preloaded": items are then read from disk on access
    self.images = None
    if load_img:
      self.images = [self._load(img_path) for img_path in tqdm(data["image_path"])]

  @staticmethod
  def _upsample(data, column):
    """Return `data` plus randomly repeated rows so all `column` groups are equally large."""
    groups, counts = np.unique(data[column], return_counts=True)
    largest = counts.max()

    extras = []
    for value, count in zip(groups, counts):
      missing = largest - count
      positions = np.where(data[column] == value)[0]
      picked = np.random.choice(positions, size=missing)
      if missing:
        extras.append(data.iloc[picked])

    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    return pd.concat([data, *extras], ignore_index=True)

  def _load(self, img_path):
    """Open one image, run it through the transform pipeline and close the file."""
    with Image.open(img_path) as img:
      return self.transform(img)

  def __len__(self):
    # one item per row, whether or not the images were preloaded
    return (len(self.labels))

  def __getitem__(self, i):
    image = self.images[i] if self.images is not None else self._load(self.imgage_path[i])
    label = self.labels[i]
    feature = self.features[i]
    return image, torch.tensor(label, dtype=torch.long), torch.tensor(feature, dtype=torch.long)

# Model `MTLNet`

In [9]:
# create the MTL network
from torch import nn
from torch import optim
import torchvision.models as models

class MTLNet(nn.Module):
    """Multi-task network: a shared ResNet-50 trunk with two linear heads.

    One head outputs `num_classes` scores (labels), the other `num_features`
    scores (features). `forward` returns both as a tuple.
    """

    def __init__(self, num_classes, num_features):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = num_features

        # pretrained backbone that is modified below
        backbone = models.resnet50(pretrained=True)

        # Start with every parameter frozen, then re-enable gradients for the
        # last residual stage (layer4) and for every BatchNorm2d layer.
        backbone.requires_grad_(False)
        backbone.layer4.requires_grad_(True)
        for layer in backbone.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.requires_grad_(True)

        # width of the vector that the original fc layer consumed
        head_in = backbone.fc.in_features

        # shared trunk: every child of the backbone except the final fc layer
        trunk = nn.Sequential(*list(backbone.children())[:-1])
        self.base_model = nn.Sequential(trunk)

        # labels head
        self.labels_head = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Flatten(),
            nn.Linear(in_features=head_in, out_features=num_classes, bias=True),
        )

        # features head
        self.features_head = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Flatten(),
            nn.Linear(in_features=head_in, out_features=num_features, bias=True),
        )

    def forward(self, x):
        """Return `(label_scores, feature_scores)` for the image batch `x`."""
        shared = self.base_model(x)  # common part; each head flattens it itself
        return self.labels_head(shared), self.features_head(shared)
    

# Define train loop for `DomainNet`

In [10]:
# Train the net
from HelperFunctions.project_utils import Tracker, plot_tracker
from sklearn.metrics import accuracy_score, f1_score

def train_domain_net(net: DomainNet, criterion, optimizer, device,
                     trainloader: DataLoader, validationloader: DataLoader = None,
                     validation_on: bool = False, num_epoch = 100, eval_every = 3,
                     plt_on: bool = False):
  """Train `net` and optionally validate it, recording metrics in a Tracker.

  Args:
    net: model to optimise.
    criterion: loss called as `criterion(output, class_indices)`.
    optimizer: optimiser stepping the parameters of `net`.
    device: device the batches are moved to.
    trainloader: yields `(input_batch, one_hot_targets)` pairs.
    validationloader: same format; required when `validation_on` is True.
    validation_on: run validation every `eval_every` epochs and on the last epoch.
    num_epoch: number of training epochs.
    eval_every: validation period in epochs.
    plt_on: plot the tracked curves with `plt_tracker` at the end.

  Returns:
    The Tracker holding the recorded losses and weighted F1 scores.

  Raises:
    ValueError: if `validation_on` is True but no `validationloader` is given.
  """
  if validation_on and validationloader is None:
    raise ValueError("validation_on=True requires a validationloader")

  # Initialize tracker
  tracker = Tracker()

  def run_batch(batch):
    """Forward one batch, record its loss and weighted F1 in the tracker, return the loss."""
    input_batch, targets = batch
    input_batch, targets = input_batch.to(device), targets.to(device)

    output = net(input_batch)

    # labels: targets arrive one-hot encoded; softmax is monotonic, so the
    # argmax over the raw scores is the predicted class
    true_class = torch.argmax(targets,dim=1)
    preds = torch.argmax(output,dim=1)

    loss = criterion(output, true_class)
    tracker.batch_loss.append(loss.item() / input_batch.size(0))

    acc = f1_score(true_class.cpu(), preds.cpu(), average='weighted')
    tracker.batch_acc.append(acc)
    return loss

  for epoch in range(num_epoch):
    # Train
    net.train()
    for batch in trainloader:
      loss = run_batch(batch)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # Update training values with batch results
    tracker.train_update(epoch)

    # Validate
    if validation_on and (epoch % eval_every == 0 or epoch == num_epoch - 1):
      net.eval()
      with torch.no_grad():
        # evaluate on the validation batches themselves, not on whatever
        # training batch was processed last
        for batch in validationloader:
          run_batch(batch)

      tracker.val_update(epoch)

  if plt_on: plt_tracker(tracker, num_epoch)
  return tracker



# Test `DomainNet`



In [11]:

def test_domain_net(net: DomainNet, testloader: DataLoader, device):
  """Run `net` over every batch of `testloader` and collect its predictions.

  The network is put in evaluation mode for the duration of the call (so
  dropout is inactive) and its previous mode is restored afterwards; no
  gradients are recorded.

  Args:
    net: model to evaluate.
    testloader: yields `(input_batch, one_hot_targets)` pairs.
    device: device the inputs are moved to.

  Returns:
    dict with "probs" (one softmax vector per sample), "preds" (predicted
    class index per sample) and "targets" (true class index per sample),
    accumulated over all batches. The lists are empty for an empty loader.
  """
  test_probs = []
  test_preds = []
  test_targets = []

  was_training = net.training
  net.eval()
  try:
    with torch.no_grad():
      for input_batch, one_hot_target in testloader:
        output = net(input_batch.to(device))

        targets = torch.argmax(one_hot_target,dim=1)
        probs = nn.functional.softmax(output, dim = 1)
        preds = torch.argmax(probs,dim=1)

        test_probs.extend(probs.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())
        test_targets.extend(targets.cpu().numpy())
  finally:
    net.train(was_training)

  # returned only after the whole loader has been consumed
  return {"probs": test_probs, "preds": test_preds, "targets": test_targets}

    

# Functions

In [12]:
def plot_conf(conf,target_labels):
  """Show the matrix `conf` as an annotated blue heat map labelled with `target_labels` on both axes."""
  tick_names = list(target_labels)
  df_cm = pd.DataFrame(conf, index=tick_names, columns=tick_names)

  plt.figure(figsize = (7,4))
  sns.heatmap(df_cm, annot=True, cmap="Blues")
  plt.show()


# apply domain knowledge
def get_domain_probabilites(mtl_features, domain: pd.DataFrame):
  """Score `mtl_features` against the domain feature vector of every row in `domain`.

  For each row, the row's `features` list is added element-wise to
  `mtl_features`, passed through the module-level `domain_net`, and the
  softmax of the output is stored in the result row addressed by the row's
  index label. Relies on the globals `NUM_CLASSES`, `device` and `domain_net`.

  Returns:
    ndarray of shape (NUM_CLASSES, NUM_CLASSES).
  """
  predicted = np.array(mtl_features)
  d_probs = np.zeros(shape=(NUM_CLASSES,NUM_CLASSES))

  for row_label, dom_feat in domain["features"].items():
    combined = np.array(dom_feat) + predicted
    scores = domain_net(torch.Tensor(combined).to(device))
    row_probs = nn.functional.softmax(scores, dim = 0)
    d_probs[row_label] = row_probs.cpu().detach().numpy()

  return d_probs

def apply_domain_correction(mtl_probs, mtl_features, domain: pd.DataFrame):
  """Re-weight the MTL class probabilities with the domain network's scores.

  When the largest entry of `mtl_probs` is at most 1.0, the domain
  probabilities are summed over the domain rows, multiplied element-wise with
  `mtl_probs`, and the function returns
  `(argmax of the product, argmax of the summed domain probabilities)`.
  Otherwise it returns `(argmax of mtl_probs, mtl_probs)` unchanged.

  NOTE(review): the second element differs in kind between the two branches
  (an index vs. the probability vector) - confirm callers expect this.
  """
  # fallback: input does not look like probabilities, skip the correction
  if not max(mtl_probs) <= 1.0:
    return mtl_probs.argmax(axis=0), mtl_probs

  per_row_probs = get_domain_probabilites(mtl_features, domain)
  domain_probs = np.sum(per_row_probs, axis=0)

  combined_probs = mtl_probs * domain_probs

  return combined_probs.argmax(axis=0), domain_probs.argmax(axis=0)

# Check corrections

In [37]:
from HelperFunctions.project_utils import KFoldResult
from sklearn.metrics import classification_report
import seaborn as sns


k = 5
# Per-fold tallies, in the order in which they fill one row of `corr_counts`
count_columns = ["change", "TTnC", "FTC", "TFC", "FFC", "FFnC"]

corr_results = []
# sized from k and the tally list so it always matches the loop below
corr_counts = np.zeros(shape=(k, len(count_columns)))
corr_f1 = []
mtl_f1 = []

test_len=[]

# Sorted class names, used to turn class indices into names; identical for every fold
target_labels = sorted(data.diagnosis.unique())

for i in range(k):

  # COLLECT RUN VARIABLES
  k_name = f"K_fold/Correction_FINAL_kfold_NA_{i}.json"
  res = KFoldResult(k_name)
  test_len.append(len(res.test_idx))
  true_labels = res.test_labels_targets
  true_features = res.test_features_targets

  mtl_probs = res.labels_probs.numpy()
  mtl_preds = res.labels_preds
  corrected_preds = res.corrected_preds

  # Cast to np.array
  correction_preds = np.array(corrected_preds)
  change = (mtl_preds!=correction_preds).astype(int)

  fold_df = pd.DataFrame()
  fold_df["disease"] = [target_labels[x] for x in true_labels]  # Targets as class names
  fold_df["target"] = true_labels                               # Targets as class indices
  fold_df["mtl_pred"] = mtl_preds                               # MTL predictions
  fold_df["corrected_pred"] = correction_preds                  # Correction prediction
  fold_df["change"] = change                                    # Did it change? [0/1]

  # Collect change effects (MTL prediction => corrected prediction):

  # CORRECT => CORRECT (GOOD)   [CHANGE = 0] (correct_nochange)   - TTnC
  # WRONG   => CORRECT (GOOD)   [CHANGE = 1] (correct_change)     - FTC
  # CORRECT => WRONG (BAD)      [CHANGE = 1] (incorrect_change)   - TFC
  # WRONG   => WRONG (BAD)      [CHANGE = 1] (incorrect_change)   - FFC
  # WRONG   => SAME  (BAD)      [CHANGE = 0] (incorrect_nochange) - FFnC

  fold_df["TTnC"] = ((mtl_preds == true_labels) & (correction_preds == true_labels) & (change==0)).astype(int)
  fold_df["FTC"]  = ((mtl_preds != true_labels) & (correction_preds == true_labels) & (change==1)).astype(int)
  fold_df["TFC"]  = ((mtl_preds == true_labels) & (correction_preds != true_labels) & (change==1)).astype(int)
  fold_df["FFC"]  = ((mtl_preds != true_labels) & (correction_preds != true_labels) & (change==1)).astype(int)
  fold_df["FFnC"]  = ((mtl_preds != true_labels) & (correction_preds != true_labels) & (change==0)).astype(int)

  total = fold_df.agg({name: "sum" for name in count_columns})

  corr_results.append(fold_df)
  corr_counts[i] = total.values
  corr_f1.append(f1_score(true_labels, correction_preds,average="weighted"))
  mtl_f1.append(f1_score(true_labels, mtl_preds,average="weighted"))

# Summary over folds. The "FnC" header holds the per-fold "FFnC" tally.
corr_df = pd.DataFrame(data=corr_counts, columns=["Change","TTnC","FTC","TFC","FFC","FnC"])
corr_df["F1 MTL"] = mtl_f1
corr_df["F1 Correction"] = corr_f1

tab=tabulate(corr_df,headers=corr_df.columns.to_list(),tablefmt="latex_raw",floatfmt=".2f")
print(tab)

print("\nMean and std:")
print(tabulate({
    " ": ["MTL", "Correction"],
    "mean": [np.mean(mtl_f1), np.mean(corr_f1)],
    "std": [np.std(mtl_f1), np.std(corr_f1)]
}, headers = "keys", floatfmt=".2f"))

\begin{tabular}{rrrrrrrrr}
\hline
    &   Change &   TTnC &   FTC &   TFC &   FFC &   FnC &   F1 MTL &   F1 Correction \\
\hline
  0 &     2.00 &  65.00 &  2.00 &  0.00 &  0.00 & 24.00 &     0.71 &            0.73 \\
  1 &     2.00 &  66.00 &  0.00 &  2.00 &  0.00 & 23.00 &     0.74 &            0.72 \\
  2 &     4.00 &  67.00 &  2.00 &  2.00 &  0.00 & 20.00 &     0.75 &            0.76 \\
  3 &     1.00 &  72.00 &  0.00 &  0.00 &  1.00 & 17.00 &     0.80 &            0.80 \\
  4 &     4.00 &  65.00 &  3.00 &  1.00 &  0.00 & 21.00 &     0.73 &            0.75 \\
\hline
\end{tabular}

Mean and std:
              mean    std
----------  ------  -----
MTL           0.75   0.03
Correction    0.75   0.03


In [31]:
# Express the per-fold tallies (every column except the two F1 columns) as a
# fraction of that fold's test-set size, then summarise across folds.
tally_names = corr_df.columns.to_list()[:-2]
cts = corr_df[tally_names].to_numpy()
lens = np.array(test_len)

# divide every row by the size of its own fold
perc = (cts.transpose() / lens).transpose()

mp, sd = perc.mean(axis=0), perc.std(axis=0)
d = pd.DataFrame(perc, columns=tally_names)
tab = tabulate(perc, headers=d.columns.to_list())
print(tab)

# mean and standard deviation across folds, in percent
print("\nmean:")
print(np.round(mp*100,2))
print("std:")
print(sd*100)

   Change      TTnC        FTC        TFC        FFC       FnC
---------  --------  ---------  ---------  ---------  --------
0.021978   0.714286  0.021978   0          0          0.263736
0.021978   0.725275  0          0.021978   0          0.252747
0.043956   0.736264  0.021978   0.021978   0          0.21978
0.0111111  0.8       0          0          0.0111111  0.188889
0.0444444  0.722222  0.0333333  0.0111111  0          0.233333

mean:
[ 2.87 73.96  1.55  1.1   0.22 23.17]
std:
[1.32693119 3.10074312 1.32849422 0.98289916 0.44444444 2.62558534]
