# High-Resolution Interpretable Classification of Artifacts versus Real Variants in Whole Genome Sequencing Data from Archived Tissue  <br/> Domenico & Asimomitis et al.

# Imports

In [None]:
import time
import random
import datetime
import glob
import os
from os.path import join, exists
import numpy as np
import pandas as pd
import cv2
import functools

import tensorflow as tf

from Model_class import *
from ICML_utils import *

In [None]:
DATA_DIR = "../data"
LOG_DIR = "../logs"
# exist_ok=True creates the directory when missing and is a no-op otherwise,
# without the race of a separate exists() check followed by makedirs().
os.makedirs(LOG_DIR, exist_ok=True)

# Helper Functions

In [None]:
def convert_to_three_channels(t, c1=0, c2=2, c3=5):
    """Return a 3-channel image built from channels ``c1``, ``c2`` and ``c3`` of ``t``.

    ``t`` must be an eager, rank-3, channels-last tensor; the selected channels
    are stacked along the last axis in the order given. Because this calls
    ``.numpy()`` it only works eagerly, which is why ``_parse_function`` wraps
    it in ``tf.py_function``.
    """
    pixels = t.numpy()
    picked = [pixels[:, :, channel] for channel in (c1, c2, c3)]
    merged = cv2.merge(picked)
    return tf.convert_to_tensor(merged)

def _parse_function(proto, c1, c2, c3):
    """Decode one serialized pileup example into an ``(image, label)`` pair.

    Args:
        proto: a serialized ``tf.train.Example`` with a bytes feature
            ``image/encoded`` and an int64 feature ``label``.
        c1, c2, c3: indices of the channels to keep, in output order
            (forwarded to ``convert_to_three_channels``).

    Returns:
        ``(x, y)``: ``x`` is a uint8 tensor of shape [100, 221, 3] and ``y`` is
        the int64 label.
    """
    features = {
        "image/encoded": tf.io.FixedLenFeature((), tf.string),
        "label": tf.io.FixedLenFeature((), tf.int64),
    }
    parsed_features = tf.io.parse_single_example(proto, features)

    # The encoded image is a flat uint8 buffer; reshape it to 100 x 221 x 6.
    x = tf.reshape(tf.io.decode_raw(parsed_features["image/encoded"], tf.uint8), [100, 221, 6])

    # Keep only three of the six channels. tf.py_function runs the eager helper
    # but returns a tensor with unknown static shape, so restore it explicitly
    # so that batching and downstream ops see [100, 221, 3].
    x = tf.py_function(func=convert_to_three_channels, inp=[x, c1, c2, c3], Tout=tf.uint8)
    x.set_shape([100, 221, 3])

    # The label feature is declared as int64 above, so no cast is needed.
    y = parsed_features["label"]
    return x, y

def load_dataset(data, batch_size, c1=0, c2=2, c3=5):
    """Parse and batch a TFRecord dataset of pileup examples, returning an iterator.

    Each record is decoded by ``_parse_function`` (keeping channels ``c1``,
    ``c2``, ``c3``) into an ``(image, label)`` pair, using up to 16 parallel
    map calls, and the pairs are grouped into batches of ``batch_size``.

    Returns:
        A Python iterator over ``(images, labels)`` batches.
    """
    def parse_record(proto):
        return _parse_function(proto, c1=c1, c2=c2, c3=c3)

    batched = data.map(parse_record, num_parallel_calls=16).batch(batch_size)
    return iter(batched)

In [None]:
def tf_to_torch(tensor):
    """Convert a batch of channels-last images into a float32 channels-first torch tensor.

    Args:
        tensor: an iterable batch of images. Each element must expose
            ``.numpy()`` returning an (H, W, C) array, typically uint8 as
            produced by ``load_dataset``.

    Returns:
        A ``torch.float32`` tensor of shape (N, C, H, W) with pixel values
        divided by 255. H, W and C are taken from the data rather than
        hard-coded. An empty batch yields an empty tensor of shape
        (0, 3, 100, 221), matching the pipeline's default image size.
    """
    # (H, W, C) -> (C, H, W) for each image in the batch.
    images = [np.transpose(im.numpy(), (2, 0, 1)) for im in tensor]
    if not images:
        return torch.zeros((0, 3, 100, 221), dtype=torch.float32)
    # Scale once in float32; no float64 intermediate is needed.
    batch = np.stack(images).astype(np.float32) / 255.0
    return torch.from_numpy(batch)

In [None]:
def write_log(content, filename):
    """Append ``content`` verbatim to the text file ``filename``.

    No newline is added; callers include their own. The file is created if it
    does not exist and is closed again before returning.
    """
    with open(filename, mode="a") as log_handle:
        log_handle.write(content)

# 1. Preprocessing

The first step is to pass each variant through DeepVariant's <b>make_examples</b> module. Details are documented here: https://github.com/google/deepvariant/blob/r1.5/docs/deepvariant-details.md#make_examples

Mutation pileup images are stored as protos in TFRecord format and can be manipulated further using TensorFlow. Pileup images can be generated for any variant call set with the <i>--variant_caller vcf_candidate_importer</i> option; in that case, the user must provide <b>make_examples</b> with a VCF (<i>--proposed_variants</i>), a variant BED (<i>--regions</i>), a BAM (<i>--reads</i>), and a reference FASTA (<i>--ref</i>).

Once generated, these files can be read in as follows:

In [None]:
# Both inputs are GZIP-compressed TFRecord files located in DATA_DIR.
artifact_dataset, real_variant_dataset = (
    tf.data.TFRecordDataset(join(DATA_DIR, file_name), compression_type="GZIP")
    for file_name in ("artifact.tfrecord.gz", "real_variant.tfrecord.gz")
)

They can then be converted to PNGs (for ease of viewing):

In [None]:
def _export_pngs(iterator, prefix):
    """Write every image yielded by ``iterator`` to DATA_DIR as a PNG file.

    ``iterator`` yields ``(images, labels)`` batches as returned by
    ``load_dataset``. The first image is saved as ``{prefix}.png`` and each
    later one as ``{prefix}_{i}.png`` (i = 1, 2, ...), so images no longer
    overwrite one another.

    Returns:
        The number of PNG files written.
    """
    n_written = 0
    for batch, _ in iterator:
        for im in tf_to_torch(batch):
            suffix = f"_{n_written}" if n_written else ""
            save_tensor_as_png(im, join(DATA_DIR, f"{prefix}{suffix}.png"))
            n_written += 1
    return n_written


artifact_iterator = load_dataset(artifact_dataset, 1)
n_artifact_pngs = _export_pngs(artifact_iterator, "artifact")
real_iterator = load_dataset(real_variant_dataset, 1)
n_real_variant_pngs = _export_pngs(real_iterator, "real_variant")

# 2. Build Model

For training, the following must be specified below:
- The training and validation directories, each filled with example PNG images.
- A tab-separated labels file with a `name` column holding each image's filename (without extension) and a `label` column holding its integer class label.

In [None]:
# Root directory of the PNG training material; point this at your own images.
icml_image_root = "/work/isabl/home/domenicd/benchmarking/ffpe/cnn/notebook/icml_images"
train_dir = join(icml_image_root, "train")          # training PNGs
valid_dir = join(icml_image_root, "valid")          # validation PNGs
labels_path = join(icml_image_root, "labels.tsv")   # tab-separated name/label table

Hyperparameters are defined below and can be adjusted as needed:

In [None]:
# Hyper parameters
model_name = "test"   # prefix of the saved checkpoint file name
num_classes = 2       # number of output classes of the classifier
num_epochs = 10       # passes over the training set
BATCH_SIZE = 16       # images per optimisation step
lr = 1e-3             # Adam learning rate

# Let PyTorch use up to 32 threads for intra-op parallelism on the CPU.
torch.set_num_threads(32)

In [None]:
# Model initialization: use the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = MyModel(pretrained=False, n_classes=num_classes).to(device)

# CrossEntropyLoss applies log-softmax itself, so it expects unnormalised model outputs.
criterion = nn.CrossEntropyLoss()
# NOTE(review): only the parameters of ``model.Model`` are given to Adam; any
# layers MyModel keeps outside ``.Model`` would stay frozen. Confirm this is intended.
optimizer = torch.optim.Adam(model.Model.parameters(), lr=lr)

In [None]:
# Timestamp identifying this run; used in the log and checkpoint file names.
runstart = datetime.datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
start = time.time()
LOG_EVERY = 1  # write a training-progress line every LOG_EVERY batches


def _png_stem(path):
    """Return ``path``'s file name without its directory or extension.

    This is the key looked up in the ``name`` column of the labels file.
    ``os.path.splitext`` removes only the extension, whereas
    ``str.strip(".png")`` would also remove any leading/trailing '.', 'p',
    'n' or 'g' characters that belong to the name itself.
    """
    return os.path.splitext(os.path.basename(path))[0]


def _make_batches(examples, batch_size):
    """Split ``examples`` into consecutive lists of at most ``batch_size`` items."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


def _load_batch(batch):
    """Load a list of ``(png_path, label)`` pairs as ``(images, targets)`` on ``device``.

    Each PNG is read at 221x221 with ``read_png_as_tensor`` and the results
    are concatenated along dim 0 as float32. Labels become a 1-D ``long``
    tensor of class indices, as ``nn.CrossEntropyLoss`` expects.
    """
    paths = [path for path, _ in batch]
    batch_labels = [label for _, label in batch]
    # No preallocated ``out`` tensor: the final batch may hold fewer than BATCH_SIZE images.
    images = torch.cat([read_png_as_tensor(path, 221, 221)[1] for path in paths]).float()
    targets = torch.from_numpy(np.asarray(batch_labels)).long()
    return images.to(device), targets.to(device)


def _count_correct(outputs, targets):
    """Return how many argmax predictions in ``outputs`` equal ``targets``.

    This stays on-device (no ``.numpy()``), so it also works when training on
    CUDA. Softmax is monotonic, so the argmax of the logits equals the argmax
    of the class probabilities.
    """
    return (outputs.argmax(dim=1) == targets).sum().item()


labels = pd.read_csv(labels_path, sep="\t").set_index("name")
train_data = glob.glob(os.path.join(train_dir, "*.png"))
valid_data = glob.glob(os.path.join(valid_dir, "*.png"))

TRAIN_SIZE = len(train_data)
VAL_SIZE = len(valid_data)

# Pair every image path with its label from the labels file.
train_dataset = [(x, labels.loc[_png_stem(x), "label"]) for x in train_data]
valid_dataset = [(x, labels.loc[_png_stem(x), "label"]) for x in valid_data]

# Run training
logfile = join(LOG_DIR, f"{runstart}_training.log")
print(logfile)
print("Executing training")
val_accuracy = val_loss = None
for epoch in range(num_epochs):
    # ---- Training pass over every batch of the reshuffled training set ----
    random.shuffle(train_dataset)
    train_batches = _make_batches(train_dataset, BATCH_SIZE)
    n_train_batches = len(train_batches)  # exact count, including a short last batch
    correct = 0
    total = 0
    running_loss = 0.0

    model.train()
    print("")
    print(f"Epoch {epoch+1}")
    for batch_i, batch in enumerate(train_batches):
        print(f"Batch {batch_i+1}/{n_train_batches}", end='\r')
        data, target = _load_batch(batch)

        # Clear gradients, forward pass, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model.forward(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += _count_correct(outputs, target)
        total += target.size(0)
        if (batch_i + 1) % LOG_EVERY == 0:
            accuracy = 100 * correct / total
            write_log('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Training Accuracy: {:.2f}%\n'
                      .format(epoch+1, num_epochs, batch_i+1, n_train_batches, loss.item(), accuracy), logfile)
            write_log(f"Time elapsed: {time.time()-start:.02f} seconds\n", logfile)

    # ---- Validation pass: no parameter updates and no autograd graph ----
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    valid_batches = _make_batches(valid_dataset, BATCH_SIZE)
    with torch.no_grad():
        for batch in valid_batches:
            data, target = _load_batch(batch)
            outputs = model.forward(data)
            running_loss += criterion(outputs, target).item()
            correct += _count_correct(outputs, target)
            total += target.size(0)

    if total:
        val_accuracy = 100 * correct / total
        val_loss = running_loss / len(valid_batches)
        write_log('Epoch [{}/{}] Finished, Validation Loss: {:.4f}, Validation Accuracy: {:.2f}%\n'
                  .format(epoch+1, num_epochs, val_loss, val_accuracy), logfile)
    else:
        # Avoid dividing by zero when the validation directory has no PNGs.
        write_log('Epoch [{}/{}] Finished, no validation images found\n'
                  .format(epoch+1, num_epochs), logfile)
    write_log(f"Time elapsed: {time.time()-start:.02f} seconds\n", logfile)
    write_log("----------------------------------------\n", logfile)

torch.save(model.state_dict(), join(DATA_DIR, f"{model_name}_{runstart}_final.pt"))