# **Importing Necessary Libraries**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import *
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import shutil
from torchvision.utils import draw_bounding_boxes

# **Define 4-layer CNN**

In [None]:
class SimpleCNN4(nn.Module):
    """Four-convolution feature extractor producing an ``ftr_size`` vector.

    Each convolution is 3x3 / stride 1 / padding 1, so only the 2x2 max-pools
    change the spatial size. ``forward`` pools after the first three
    convolutions only; ``fc1`` takes 200704 = 256 * 28 * 28 inputs, which
    matches a 3x224x224 image after those three poolings.
    """

    def __init__(self, ftr_size=512):
        super().__init__()

        # Channel widths of the conv stack: input RGB followed by 4 stages.
        widths = (3, 32, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv2d(in_channels=c_in, out_channels=c_out,
                              kernel_size=3, stride=1, padding=1))
            # pool4 is registered like the others but is not used by forward().
            setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        self.fc1 = nn.Linear(200704, ftr_size)
        self.flat = nn.Flatten()

    def forward(self, x):
        """Return the ReLU-activated ``(batch, ftr_size)`` feature vector."""
        pooled_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in pooled_stages:
            x = pool(F.relu(conv(x)))

        # Last convolution is followed directly by flatten + linear (no pool).
        x = F.relu(self.conv4(x))
        return F.relu(self.fc1(self.flat(x)))

# **Define 8-layer CNN**

In [None]:
class SimpleCNN8(nn.Module):
    """Eight-convolution feature extractor producing an ``ftr_size`` vector.

    All convolutions are 3x3 / stride 1 / padding 1. The first five are each
    followed by a 2x2 max-pool; the last three are not. ``fc1`` takes
    98000 = 2000 * 7 * 7 inputs, which matches a 3x224x224 image after the
    five poolings.
    """

    # Number of leading conv stages that are followed by a max-pool.
    _POOLED_STAGES = 5

    def __init__(self, ftr_size=512):
        super().__init__()

        # Channel widths of the conv stack: input RGB followed by 8 stages.
        widths = (3, 32, 64, 128, 255, 512, 720, 1024, 2000)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv2d(in_channels=c_in, out_channels=c_out,
                              kernel_size=3, stride=1, padding=1))
            if stage <= self._POOLED_STAGES:
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        self.fc1 = nn.Linear(98000, ftr_size)
        self.flat = nn.Flatten()

    def forward(self, x):
        """Return the ReLU-activated ``(batch, ftr_size)`` feature vector."""
        pooled_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
            (self.conv5, self.pool5),
        )
        for conv, pool in pooled_stages:
            x = pool(F.relu(conv(x)))

        for conv in (self.conv6, self.conv7, self.conv8):
            x = F.relu(conv(x))

        # Flatten the (2000, H, W) feature map into the vector fc1 expects
        # (98000 values per sample).
        x = self.flat(x)
        return F.relu(self.fc1(x))

# **Defining Attention Network Class**

In [None]:
class AttentionNetwork(torch.nn.Module):
    """Classifier that combines a cropped "attention" view with the full image.

    ``forward`` runs three sub-networks:

    * ``Attention_net`` maps the image to a feature vector from which
      ``Attention_FC`` predicts a box (centre ``tx, ty`` and half-extents
      ``tl_x, tl_y``, all squashed to [0, 1] by a sigmoid).
    * ``Attention_ftr_extrcator`` extracts local features from the image
      cropped to that box and resized back to the input resolution.
    * ``global_net`` extracts global features from the *uncropped* image.

    Both feature vectors are concatenated and passed through
    ``aggregate_net``.

    Args:
        AttentionNet: module producing a ``(batch, ftr_size)`` vector used to
            predict the attention box.
        AttentionFtrExtractor: module producing ``(batch, ftr_size)`` local
            features from the cropped image.
        GlobalNet: module producing ``(batch, ftr_size)`` global features.
        num_classes: number of classes; 2 yields a single output unit.
        size: side length (pixels) used to scale the predicted box.
        ftr_size: width of each sub-network's feature vector.
    """

    def __init__(self, AttentionNet=None, AttentionFtrExtractor=None, GlobalNet=None, num_classes=2, size=224, ftr_size=512):
        super(AttentionNetwork, self).__init__()

        self.Attention_net = AttentionNet  # network for computing attention maps

        # (local) feature extractor from attention maps; the attribute name's
        # spelling is kept because it is part of the state_dict keys.
        self.Attention_ftr_extrcator = AttentionFtrExtractor

        self.global_net = GlobalNet  # global feature extractor

        # Binary classification uses one sigmoid unit instead of two outputs.
        out_size = 1 if num_classes == 2 else num_classes

        # network for aggregating global and local information
        # NOTE(review): the head ends in Sigmoid, so outputs are in [0, 1];
        # losses that expect raw logits (e.g. nn.CrossEntropyLoss) will see
        # squashed values - confirm this is intended.
        self.aggregate_net = nn.Sequential(nn.Linear(ftr_size*2, ftr_size),
                                           nn.Sigmoid(),
                                           nn.Linear(ftr_size, out_size),
                                           nn.Sigmoid())

        self.size = size

        # fully-connected layer that predicts the 4 attention-box parameters
        self.Attention_FC = nn.Sequential(nn.Linear(ftr_size, 4), nn.Sigmoid())
        # Not used by forward(); kept so existing checkpoints still load.
        self.classifier_FC = nn.Linear(ftr_size, num_classes)

    def draw_attention_map(self, img, x_min, y_min, x_max, y_max):
        """Return one copy of each image with its attention box drawn in red.

        Args:
            img: batch of images, ``(batch, C, H, W)``; values are scaled by
                255 and cast to uint8, so presumably in [0, 1].
            x_min, y_min, x_max, y_max: per-image integer box corners.

        Returns:
            list of float32 tensors (values divided back by 255), one per image.
        """
        images = []
        for i in range(img.shape[0]):
            img_uint8 = torch.round(255*img[i]).to(torch.uint8)
            corners = [x_min[i].item(), y_min[i].item(), x_max[i].item(), y_max[i].item()]
            bbox = torch.tensor([corners], dtype=torch.int)
            img_with_rect = draw_bounding_boxes(img_uint8, bbox, width=6, colors=[(255, 0, 0)], fill=False, font_size=20)
            images.append((img_with_rect / 255).to(torch.float32))

        return images

    def _crop_and_resize(self, image, x_min, y_min, x_max, y_max):
        """Crop each image to its box and resize it back to the input H x W.

        A new tensor is returned; ``image`` is left untouched so that callers
        (in particular the global branch of ``forward``) still see the
        original pixels.
        """
        if image.shape[0] == 0:
            return image.clone()

        out_hw = tuple(image.shape[-2:])
        crops = []
        for i in range(image.shape[0]):
            top, bottom = int(y_min[i]), int(y_max[i])
            left, right = int(x_min[i]), int(x_max[i])
            patch = image[i:i + 1, :, top:bottom, left:right]
            crops.append(F.interpolate(patch, out_hw, mode='bilinear'))
        return torch.cat(crops, dim=0)

    def crop_zoom(self, image, tx, ty, tl_x, tl_y):
        """Crop and zoom into the attention region of every image.

        Args:
            image: batch of images, ``(batch, C, H, W)``; not modified.
            tx, ty: per-image box centre as a fraction of ``self.size``.
            tl_x, tl_y: per-image half-width / half-height as a fraction of
                ``self.size / 2``.

        Returns:
            ``(cropped, imgs_with_rect)``: the cropped-and-resized batch (a new
            tensor with the same shape as ``image``) and the list of images
            with the box drawn on them.
        """
        tx_r = (self.size * tx).int()  # real tx (since 0 <= tx <= 1)
        ty_r = (self.size * ty).int()

        # +1 guarantees a half-extent of at least one pixel, so the crop is
        # never empty after clamping.
        # NOTE(review): .int() yields integer tensors, so no gradient flows
        # from the crop back to the box parameters - confirm this is intended.
        tl_x_r = ((self.size / 2) * tl_x + 1).int()
        tl_y_r = ((self.size / 2) * tl_y + 1).int()

        x_min = (tx_r - tl_x_r).clamp(min=0)
        x_max = (tx_r + tl_x_r).clamp(max=self.size)
        y_min = (ty_r - tl_y_r).clamp(min=0)
        y_max = (ty_r + tl_y_r).clamp(max=self.size)

        # draw attention region on image
        imgs_with_rect = self.draw_attention_map(image, x_min, y_min, x_max, y_max)

        # Build the cropped batch out-of-place: writing the crops back into
        # `image` would hand the global branch the cropped pixels too.
        cropped = self._crop_and_resize(image, x_min, y_min, x_max, y_max)

        return cropped, imgs_with_rect

    def aggregate(self, vec1, vec2):
        """Concatenate two feature batches along the feature dimension."""
        return torch.cat((vec1, vec2), dim=1)

    def forward(self, image):
        """Process a batch of images.

        Returns:
            ``(probs, img_with_rect)``: the output of ``aggregate_net`` and the
            list of input images annotated with their attention boxes.
        """
        # Apply attention, crop, and zoom
        vec = self.Attention_net(image)
        tx, ty, tl_x, tl_y = self.Attention_FC(vec).transpose(0, 1)
        x_cropped, img_with_rect = self.crop_zoom(image, tx, ty, tl_x, tl_y)

        # Local features come from the crop, global features from the
        # untouched full image.
        ftr_vec_local = self.Attention_ftr_extrcator(x_cropped)
        ftr_vec_global = self.global_net(image)

        # Aggregate local and global features and process them through the
        # aggregating network to get class probabilities
        ftr_vec_final = self.aggregate(ftr_vec_local, ftr_vec_global)
        probs = self.aggregate_net(ftr_vec_final)

        return probs, img_with_rect


# **Mounting Google Drive to Read Datasets**

In [None]:
# Colab-only setup: mount Google Drive at /content/drive so that paths under
# /content/drive/MyDrive (datasets, log file, checkpoints) are reachable.
from google.colab import drive
import os
drive.mount('/content/drive')



Mounted at /content/drive


# **Define model-building function**

In [None]:
# Function to build the model from the given hyper-parameters
def build_model(CNN, weight_sharing, ftr_size, num_classes):
    """Build an ``AttentionNetwork`` for the requested backbone.

    Args:
        CNN: backbone name; one of ``'4-layer'``, ``'8-layer'``,
            ``'ResNet18'`` or ``'ResNet50'``. The ResNets are loaded with
            pretrained weights and their final ``fc`` layer is replaced by a
            fresh ``Linear(in_features, ftr_size)``.
        weight_sharing: if truthy, a single backbone instance is used for the
            attention, local-feature and global-feature branches; otherwise
            three independent instances are created.
        ftr_size: width of the feature vector each backbone outputs.
        num_classes: number of classes passed on to ``AttentionNetwork``.

    Returns:
        The assembled ``AttentionNetwork`` (input size fixed to 224).

    Raises:
        ValueError: if ``CNN`` is not a supported backbone name.
    """

    def _resnet(constructor):
        # Swap the ImageNet classification head for an ftr_size-wide feature head.
        net = constructor(pretrained=True)
        net.fc = nn.Linear(net.fc.in_features, ftr_size)
        return net

    # Factories are lazy so that nothing is built (or downloaded) before the
    # backbone name has been validated.
    backbone_factories = {
        '4-layer': lambda: SimpleCNN4(ftr_size=ftr_size),
        '8-layer': lambda: SimpleCNN8(ftr_size=ftr_size),
        'ResNet18': lambda: _resnet(models.resnet18),
        'ResNet50': lambda: _resnet(models.resnet50),
    }
    try:
        make_backbone = backbone_factories[CNN]
    except KeyError:
        supported = ', '.join(repr(name) for name in backbone_factories)
        raise ValueError(f"Unknown CNN {CNN!r}; expected one of: {supported}") from None

    if weight_sharing:
        # One instance shared by all three branches.
        attention_net = attention_ftr_extractor = global_net = make_backbone()
    else:
        attention_net = make_backbone()
        attention_ftr_extractor = make_backbone()
        global_net = make_backbone()

    # Initialize the model and return it
    return AttentionNetwork(AttentionNet=attention_net,
                            AttentionFtrExtractor=attention_ftr_extractor,
                            GlobalNet=global_net,
                            num_classes=num_classes,
                            size=224,
                            ftr_size=ftr_size)

# **Defining function that logs trial results**

In [None]:
# Log results in the file for which the file descriptor (fd) is given
def log_results(CNN, ftr_size, weight_sharing, fd, conf_mat, overall_acc, acc, sensitivity, specificity, percision, loss):
    """Write one trial's results to ``fd`` and close it.

    Output layout: a blank line, the confusion matrix (one space-separated
    row per line), then a single record
    ``CNN ftr_size weight_sharing  overall_acc acc... sensitivity...
    specificity... precision... loss``.

    Args:
        CNN: backbone name (string).
        ftr_size: feature-vector width used for the trial.
        weight_sharing: whether the branches shared one backbone.
        fd: writable text file object; it is always closed before returning,
            even if writing fails.
        conf_mat: confusion matrix as a sequence of rows.
        overall_acc: overall accuracy.
        acc, sensitivity, specificity, percision: per-class metric sequences.
        loss: loss value appended at the end of the record.
    """
    try:
        chunks = ["\n"]
        # Empty rows produce no output at all (not even a newline).
        chunks.extend(' '.join(str(cell) for cell in row) + '\n'
                      for row in conf_mat if len(row) > 0)

        # Two spaces after the parameters are part of the existing file format.
        chunks.append(CNN + ' ' + str(ftr_size) + ' ' + str(weight_sharing) + ' ' + ' ')
        chunks.append(str(overall_acc) + ' ')
        for metric in (acc, sensitivity, specificity, percision):
            chunks.extend(str(value) + ' ' for value in metric)
        chunks.append(str(loss))

        fd.write(''.join(chunks))
    finally:
        # The caller hands over ownership of fd; never leave it open.
        fd.close()

# **Define function that computes evaluation metrics**

In [None]:
# Compute the evaluation metrics (accuracy, sensitivity, specificity, and precision)
# for each class
def eval_metrics(conf_mat):
    """Compute per-class metrics (in percent) from a square confusion matrix.

    ``conf_mat[t][p]`` is the number of samples of true class ``t`` predicted
    as class ``p``. The number of classes is taken from ``len(conf_mat)``;
    every row is expected to have that many entries.

    Args:
        conf_mat: square confusion matrix as a sequence of rows of counts.

    Returns:
        ``(acc, sensitivity, specificity, precision)``: four lists with one
        value per class, each rounded to two decimals. A metric whose
        denominator is zero (e.g. a class with no samples) is reported as -1.
        Per-class accuracy is the fraction of the class's samples predicted
        correctly, so it equals the sensitivity.
    """
    n_classes = len(conf_mat)
    row_totals = [sum(row) for row in conf_mat]
    col_totals = [sum(conf_mat[j][i] for j in range(n_classes)) for i in range(n_classes)]
    grand_total = sum(row_totals)

    def _percent(numerator, denominator):
        # -1 marks "undefined" instead of raising ZeroDivisionError.
        if denominator == 0:
            return -1
        return round(100*numerator / denominator, 2)

    acc, sensitivity, specificity, precision = [], [], [], []
    for i in range(n_classes):
        tp = conf_mat[i][i]
        fn = row_totals[i] - tp
        fp = col_totals[i] - tp
        # Everything that is neither in row i nor in column i.
        tn = grand_total - row_totals[i] - col_totals[i] + tp

        acc.append(_percent(tp, row_totals[i]))
        sensitivity.append(_percent(tp, tp + fn))
        specificity.append(_percent(tn, tn + fp))
        precision.append(_percent(tp, tp + fp))

    return acc, sensitivity, specificity, precision

# **Defining Training Function**

In [None]:
def _evaluate(model, dataloader, dataset_size, criterion, device, num_classes):
    """Run ``model`` over ``dataloader`` and collect loss, accuracy and confusion matrix.

    The caller is responsible for ``model.eval()`` / ``torch.no_grad()``.

    Returns:
        ``(mean_batch_loss, accuracy_percent, conf_mat)`` where
        ``conf_mat[target][prediction]`` counts samples.
    """
    conf_mat = [[0] * num_classes for _ in range(num_classes)]
    running_loss = 0.0
    correct = 0
    for data, target in tqdm(dataloader):
        data, target = data.to(device), target.to(device)
        outputs, _ = model(data)
        running_loss += criterion(outputs, target).item()
        _, pred = torch.max(outputs, dim=1)
        for true_cls, pred_cls in zip(target.tolist(), pred.tolist()):
            conf_mat[int(true_cls)][int(pred_cls)] += 1
        correct += torch.sum(pred == target).item()

    return running_loss / len(dataloader), 100 * correct / dataset_size, conf_mat


def Train(data_dir=None, CNN='ResNet50', ftr_size=64, weight_sharing=True, out_file=None):
    """Train an attention model and log the best test results.

    After every epoch the model is validated; whenever the validation
    accuracy improves it is also run on the test set, and if the test
    accuracy improves too, the weights are saved to Google Drive. At the end
    the metrics of the best test run are appended to the log file.

    Args:
        data_dir: dataset root containing ``train``, ``valid`` and ``test``
            sub-folders in ``ImageFolder`` layout.
        CNN: backbone name forwarded to ``build_model``.
        ftr_size: feature-vector width forwarded to ``build_model``.
        weight_sharing: whether the three branches share one backbone.
        out_file: log file name inside ``/content/drive/MyDrive``; created
            with a header line if it does not exist, appended to otherwise.
    """
    # defining some training parameters
    batch_size = 4
    learning_rate = 0.00006
    num_classes = 3
    size = 224
    n_epochs = 3
    drive_dir = "/content/drive/MyDrive"

    # If the log file is present append to it, otherwise create it
    log_path = os.path.join(drive_dir, out_file)
    if os.path.exists(log_path):
        print("output file already exists")
        log_file = open(log_path, 'a')
    else:
        print("created log file")
        log_file = open(log_path, 'w')
        # Header lists exactly the columns that log_results writes.
        log_file.write("FORMAT: CNN ftr_size weight_sharing Acc M_acc N_acc K_acc M_sen N_sen K_sen M_spec N_spec K_spec M_per N_per K_per loss")
        log_file.write("\n\n\n")

    try:
        # Use GPU if available, use CPU otherwise
        if torch.cuda.is_available():
            print("cuda is available")
            device = torch.device("cuda")
        else:
            print("cuda is not available")
            device = torch.device("cpu")

        # Transform applied to each image: to tensor, then resize to size x size
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((size, size), antialias=None),
        ])

        print("device:", device)

        # define the train, validation, and test datasets
        train_dataset = datasets.ImageFolder(root=data_dir+"/train", transform=transform)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataset = datasets.ImageFolder(root=data_dir+"/valid", transform=transform)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_dataset = datasets.ImageFolder(root=data_dir+"/test", transform=transform)
        test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # build model and move it to the device (GPU or CPU)
        model = build_model(CNN, weight_sharing, ftr_size, num_classes)
        model.to(device)

        # Define loss function and optimizer
        # NOTE(review): the model's head ends in Sigmoid while CrossEntropyLoss
        # expects raw logits - confirm this combination is intended.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        best_val_acc = 0
        best_test_acc = 0
        best_test_loss = 2 ** 30
        best_conf_mat = [[0] * num_classes for _ in range(num_classes)]

        # NOTE(review): this runs n_epochs + 1 epochs (0..n_epochs inclusive);
        # kept as-is to preserve results - confirm whether n_epochs was meant.
        for epoch_num in range(n_epochs + 1):
            print("Epoch:", epoch_num)
            running_loss = 0.0
            correct = 0
            print("Training...")
            # Go through training dataset and update weights after each batch
            for data, target in tqdm(train_dataloader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                outputs, _ = model(data)
                loss = criterion(outputs, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                _, pred = torch.max(outputs, dim=1)
                correct += torch.sum(pred == target).item()

            print("Train acc:", 100*correct / len(train_dataset))
            print("train loss:", running_loss/len(train_dataloader))

            with torch.no_grad():
                model.eval()
                print("validating...")
                _, val_acc, _ = _evaluate(model, val_dataloader, len(val_dataset),
                                          criterion, device, num_classes)

                # If the model achieves the best validation accuracy,
                # run the model through the test set
                if best_val_acc < val_acc:
                    best_val_acc = val_acc
                    print("val acc:", best_val_acc)
                    print("testing...")
                    test_loss, test_acc, test_conf_mat = _evaluate(
                        model, test_dataloader, len(test_dataset),
                        criterion, device, num_classes)

                    acc = round(test_acc, 2)
                    loss = round(test_loss, 4)
                    # NOTE(review): checkpoints are selected on test accuracy,
                    # so the reported test metrics are optimistically biased.
                    if acc > best_test_acc:
                        best_test_acc = acc
                        best_conf_mat = test_conf_mat
                        best_test_loss = loss
                        print("test acc:", best_test_acc)
                        print("test loss:", best_test_loss)
                        print("conf_mat:", best_conf_mat)
                        torch.save(model.state_dict(), "/content/drive/MyDrive/melanoma_detection_epoch"+str(epoch_num)+".pt")

                model.train()

        # Compute evaluation metrics and log results (log_results closes the file)
        acc, sensitivity, specificity, percision = eval_metrics(best_conf_mat)
        log_results(CNN, ftr_size, weight_sharing, log_file, best_conf_mat, best_test_acc,
                    acc, sensitivity, specificity, percision, best_test_loss)
    finally:
        # Make sure the log file is released even if training fails midway.
        if not log_file.closed:
            log_file.close()

# **Train the Three-Class Model and Log the Results**

In [None]:
# Dataset root on the mounted Drive; Train() reads the "train", "valid" and
# "test" sub-folders beneath it.
data_dir = "/content/drive/MyDrive/ISIC_2019_subset"

# Train with a weight-shared ResNet50 backbone and 64-dimensional features;
# results are appended to results.txt in /content/drive/MyDrive.
Train(data_dir = data_dir, CNN='ResNet50', ftr_size=64, weight_sharing=True, out_file='results.txt')