<a href="https://colab.research.google.com/github/vellamike/training_course/blob/master/snp_caller_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import numpy as np
# Alphabet from which simulated reference and read bases are drawn.
NUCLEOTIDES = "ACGT"
# Seed the stdlib `random` generator with a fixed value so repeated runs draw the same sequence.
random.seed(3)

In [None]:
# Colours used to render integer-encoded pileups, indexed by the integer code
# (see `transdict` for the base -> code mapping and `plot_delta` for code 4).
colours = [
    "#00CC00",  # code 0 -> A
    "#0000CC",  # code 1 -> C
    "#FFB300",  # code 2 -> G
    "#CC0000",  # code 3 -> T
    "gray",     # code 4 -> read base identical to the reference (delta plots only)
    #"black"
]

# Function to simulate multiple sequence alignments, optionally with sequencing/alignment errors, containing no SNP, a heterozygous SNP, or a homozygous SNP

In [None]:
# Integer class label assigned to each mutation scenario the simulator can produce.
mutation_labels = dict(no_SNP=0, heterozygous_SNP=1, homozygous_SNP=2)

# Human-readable name for each integer class label (used in printouts and plot labels).
mutation_type_names = {
    0: "No mutation",
    1: "Heterozygous SNP",
    2: "Homozygous SNP",
}

# Nucleotide character -> integer code used in the numeric pileup arrays.
transdict = {base: code for code, base in enumerate("ACGT")}

def simulate_alignments(reference_length=200, 
                        num_alignments = 2000, 
                        coverage = 100, 
                        mutations = mutation_labels.keys(),
                        p_sequencing_error=0.0,
                        p_alignment_error=0.00):
    """Simulate read pileups around a candidate SNP at the middle of a random reference.

    For every alignment a fresh random reference is generated, one mutation
    type is drawn uniformly from ``mutations``, and ``coverage`` reads are
    produced by copying the reference and (depending on the mutation type)
    writing an alternative base at the SNP position.

    Args:
        reference_length: Number of bases in the reference and in every read.
        num_alignments: Number of pileups to simulate.
        coverage: Number of reads per pileup (the reference is added on top).
        mutations: Iterable of keys of ``mutation_labels`` to sample from.
        p_sequencing_error: Per-base probability that a read base is redrawn
            at random from ``NUCLEOTIDES`` (the redraw may return the same
            base). One extra redraw with the same probability is applied at
            the SNP position of each read.
        p_alignment_error: Per-read probability that the SNP position is
            shifted by ``random.randint(-1, 2)`` bases.

    Returns:
        A tuple ``(alignments, mutation_types)``: ``alignments`` is an array of
        shape ``(num_alignments, coverage + 1, reference_length)`` holding the
        bases encoded through ``transdict`` (row 0 of each pileup is the
        reference), and ``mutation_types`` is a list with the integer mutation
        label of each pileup.

    Progress is printed every 400 alignments. All randomness comes from the
    module-level ``random`` generator, so results depend on its current state.
    """
    alignments = []
    mutation_types = []
    
    for i in range(num_alignments):
        # The candidate SNP always sits at the middle column of the pileup.
        snp_index = reference_length // 2 
        if (i % 400 == 0):
            print("Computing alignment ", i)
        reference = [random.choice(NUCLEOTIDES) for _ in range(reference_length)]
        reference_base_at_snp = reference[snp_index]
        # The alternative allele is any base other than the reference base at the SNP.
        snp_base = random.choice([i for i in NUCLEOTIDES if i != reference_base_at_snp])
        mutation_type=random.choice([mutation_labels[m] for m in mutations]) 
        mutation_types.append(mutation_type)
        
        alignment = [reference] # first row of every pileup is always the reference
        for _ in range(coverage):
            mut_index = snp_index
            # Copy the reference, redrawing each base with probability p_sequencing_error.
            new_read = [reference[i] if random.random() > p_sequencing_error else random.choice(NUCLEOTIDES) for i in range(reference_length)]
            if random.random() < p_alignment_error:
                # randint is inclusive on both ends, so the shift is one of -1, 0, 1, 2.
                # NOTE(review): the range is asymmetric and includes 0 (no shift) - confirm this is intended.
                mut_index = snp_index + random.randint(-1,2)
            if mutation_type == 1 and random.random() > 0.5: # heterozygous SNP: about half of the reads carry the alternative base
                new_read[mut_index] = snp_base            
            if mutation_type == 2: # homozygous SNP: every read carries the alternative base
                new_read[mut_index] = snp_base
            if random.random() < p_sequencing_error: # add sequencing errors to the SNP position as well
                new_read[mut_index] =  random.choice(NUCLEOTIDES)
            alignment.append(new_read)
        alignments.append(alignment)
    alignments = np.array(alignments)
    # Translate the nucleotide characters into their integer codes element-wise.
    return np.vectorize(transdict.get)(alignments), mutation_types

In [None]:
# Compute 2000 alignments
alignments, mutation_types = simulate_alignments(num_alignments=2000)

# Visualise the alignments

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
cmap = matplotlib.colors.ListedColormap(colours)

plt.rcParams['figure.dpi'] = 200

def plot_pileup(data):
    """Render an integer-encoded pileup as an image, one row per sequence.

    Args:
        data: 2D array-like of integer codes in the range 0-4; row 0 is
            labelled as the reference and the remaining rows as reads.

    Every fifth row gets a y-axis label. The tick positions are derived from
    the number of rows in ``data`` so the labels stay correct for any
    coverage (for the default 101-row pileup they are 0, 5, ..., 100).
    """
    plt.imshow(data, cmap=cmap, vmin=0, vmax=4)
    tick_rows = range(0, len(data), 5)
    tick_labels = ["reference" if row == 0 else f"read {row}" for row in tick_rows]
    plt.yticks(tick_rows, tick_labels, fontsize=4)
    plt.show()

def plot_delta(data):
    """Plot a pileup showing only the read bases that differ from the reference.

    Row 0 of ``data`` (the reference) is kept unchanged; in every other row,
    positions equal to the reference are replaced by code 4 before the
    result is handed to ``plot_pileup``.
    """
    reference_row = data[0]
    reads = data[1:]
    matches_reference = reads == reference_row
    masked_reads = np.where(matches_reference, 4, reads)
    plot_pileup(np.vstack([reference_row[None, :], masked_reads]))

In [None]:
alignment_idx = mutation_types.index(1)
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])
plot_pileup(alignments[alignment_idx])
plot_delta(alignments[alignment_idx])

## Example 1: No mutation - no errors

In [None]:
alignment_idx = mutation_types.index(0) # take the first example 
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])
plot_pileup(alignments[alignment_idx])

## Example 2: Heterozygous mutation - no errors

In [None]:
alignment_idx = mutation_types.index(1)
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])
plot_pileup(alignments[alignment_idx])
plot_delta(alignments[alignment_idx])

## Example 3: Homozygous mutation - no errors

In [None]:
alignment_idx = mutation_types.index(2)
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])
plot_pileup(alignments[alignment_idx])
plot_delta(alignments[alignment_idx])

# Example 4: Heterozygous mutation - with sequencing and alignment errors

In [None]:
# Compute 2000 alignments
alignments, mutation_types = simulate_alignments(num_alignments=2000, p_alignment_error=0.05, p_sequencing_error=0.15)

In [None]:
alignment_idx = mutation_types.index(1)
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])
plot_pileup(alignments[alignment_idx])
plot_delta(alignments[alignment_idx])

# Exercises

1. Train the example convolutional neural network to distinguish between no SNP and heterozygous SNPs in the absence of errors.
2. Extend this to distinguish between no SNP, heterozygous SNPs, and homozygous SNPs.
3. How robust is the performance to sequencing error?
4. How robust is the performance to "alignment error"?
5. How well does a model trained on one error rate perform on another? Can this model be made more robust for the "real world"?

# PyTorch Convolutional Neural Network Example

The SNP pileups visualised above are 2D arrays with shape `[101, 200]` that can be input to a neural network.

The output of the neural network can be a 2-element vector containing the probabilities that the pileup contains no SNP or a heterozygous SNP.

The following demonstrates how to define a neural network with these inputs and outputs and train it on the alignments to predict the probabilities.

Begin by generating some data without the noise and splitting it into non-overlapping train and test sets:

In [None]:
# Compute 2000 alignments, using only no-SNP and heterozygous SNPs,  and no noise to make it easier to begin with!
alignments, mutation_types = simulate_alignments(num_alignments=2000, 
                                                 mutations=["no_SNP","heterozygous_SNP"],
                                                 p_sequencing_error=0.0, 
                                                 p_alignment_error=0.0)

In [None]:
# Randomly assign 80% of the examples to the training set and the remaining 20% to validation.
mutation_types = np.array(mutation_types)

# A fixed seed keeps the shuffled split identical from run to run.
rng = np.random.default_rng(seed=42)
idxs = np.arange(alignments.shape[0])
rng.shuffle(idxs)

split_idx = int(alignments.shape[0] * 0.8)
train_idxs, valid_idxs = idxs[:split_idx], idxs[split_idx:]
train_alignments, train_mutation_types = alignments[train_idxs], mutation_types[train_idxs]
valid_alignments, valid_mutation_types = alignments[valid_idxs], mutation_types[valid_idxs]

Next, define the neural network and the training loop, then train and evaluate the model:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:

class CNN(nn.Module):
    """Small LeNet-style CNN that classifies a single-channel pileup image.

    Architecture: two (5x5 convolution -> ReLU -> 4x4 max-pool) stages
    followed by three fully connected layers. The network returns raw class
    scores (logits); apply a softmax to turn them into probabilities.

    Args:
        num_classes: Number of output classes (defaults to 2: no SNP vs
            heterozygous SNP).
        input_shape: ``(height, width)`` of the input pileup. Defaults to
            ``(101, 200)``, i.e. a reference plus 100 reads of length 200.

    Raises:
        ValueError: If ``input_shape`` is too small to survive both
            convolution + pooling stages.
    """

    def __init__(self, num_classes=2, input_shape=(101, 200)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # input channels, output channels, kernel size
        self.conv2 = nn.Conv2d(6, 16, 5)  # input channels, output channels, kernel size
        self.pool = nn.MaxPool2d(4, 4)

        # Size the first dense layer from the input shape instead of
        # hard-coding it; the default (101, 200) input yields 16 * 5 * 11.
        height, width = (self._size_after_conv_stages(size) for size in input_shape)
        if height < 1 or width < 1:
            raise ValueError(f"input_shape {tuple(input_shape)} is too small for this network")

        self.fc1 = nn.Linear(16 * height * width, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, num_classes)

    @staticmethod
    def _size_after_conv_stages(size):
        """Return the length of one spatial dimension after both conv + pool stages."""
        for _ in range(2):
            # An unpadded 5-wide convolution removes 4 positions; a 4/4 max-pool
            # then keeps one value per complete window of 4.
            size = (size - 4) // 4
        return size

    def forward(self, x):
        """Map a ``(batch, 1, height, width)`` float tensor to ``(batch, num_classes)`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


In [None]:
def train(model, train_alignments, train_mutation_types, valid_alignments, valid_mutation_types, epochs=10, lr=0.001):
    """Train ``model`` with Adam and cross-entropy loss, validating after every epoch.

    Args:
        model: ``torch.nn.Module`` mapping ``(batch, 1, height, width)`` float
            tensors to per-class logits.
        train_alignments: NumPy array of training pileups, one per entry along
            the first axis; a channel axis is inserted before batching.
        train_mutation_types: Integer class label of every training pileup.
        valid_alignments: NumPy array of validation pileups (same layout).
        valid_mutation_types: Integer class label of every validation pileup.
        epochs: Number of passes over the training data.
        lr: Learning rate passed to Adam.

    Returns:
        A tuple ``(train_losses, valid_losses, valid_accs)`` of three lists
        with one Python float per epoch. Losses are averaged over batches;
        accuracies are fractions in ``[0, 1]``.

    A summary line is printed after each epoch. The model is updated in place
    and is left in evaluation mode when the function returns.
    """
    crit = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # unsqueeze(1) adds the single input channel expected by 2D convolutions.
    train_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(train_alignments).unsqueeze(1).float(),
        torch.tensor(train_mutation_types),
    )
    valid_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(valid_alignments).unsqueeze(1).float(),
        torch.tensor(valid_mutation_types),
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=16)

    train_losses, valid_losses, valid_accs = [], [], []

    for epoch in range(1, epochs + 1):
        # train for 1 epoch
        model.train()
        epoch_loss, total = 0.0, 0
        for batch_alignment, batch_mutation_type in train_loader:
            opt.zero_grad()
            out = model(batch_alignment)
            loss = crit(out, batch_mutation_type)
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
            total += 1
        epoch_loss /= total

        # compute validation loss and accuracy
        model.eval()
        valid_loss, n_correct, n, total = 0.0, 0, 0, 0
        for batch_alignment, batch_mutation_type in valid_loader:
            with torch.no_grad():
                out = model(batch_alignment)
                loss = crit(out, batch_mutation_type)
                # Softmax is monotonic, so the arg-max of the logits is the predicted class.
                predict = out.argmax(dim=1)

            valid_loss += loss.item()
            total += 1

            n += out.shape[0]
            # .item() keeps the counter a plain int, so the accuracy is a Python
            # float like the losses rather than a 0-dim tensor.
            n_correct += (predict == batch_mutation_type).sum().item()
        valid_loss /= total
        accuracy = n_correct / n

        train_losses.append(epoch_loss)
        valid_losses.append(valid_loss)
        valid_accs.append(accuracy)
        print(f"epoch={epoch:2d}, train_loss={epoch_loss:.3f}, valid_loss={valid_loss:.3f}, accuracy={accuracy*100:.2f}%")

    return train_losses, valid_losses, valid_accs

In [None]:
model = CNN()
n_epochs = 10
lr = 0.001

In [None]:
train_losses, valid_losses, valid_accs = train(model, train_alignments, train_mutation_types, valid_alignments, valid_mutation_types, epochs=n_epochs, lr=lr)

Computing the class probabilities for a single example using the trained model:

In [None]:
def compute_class_probability(alignments_ints):
    """Return the class probabilities predicted by the global ``model``.

    Args:
        alignments_ints: NumPy array holding either a single 2D pileup or a
            3D batch of pileups with the batch on the first axis.

    Returns:
        NumPy array of softmax probabilities with one row per pileup (a
        single pileup yields one row).
    """
    batch = torch.from_numpy(alignments_ints).float()
    if batch.dim() == 2:
        # A lone pileup: add the batch axis.
        batch = batch[None]
    # Insert the single input channel expected by the network.
    batch = batch[:, None]
    with torch.inference_mode():
        logits = model(batch)
    return logits.softmax(dim=-1).numpy()

In [None]:
# Class probabilities predicted for the first simulated pileup.
probs = compute_class_probability(alignments[0]).squeeze()
# One bar and one tick per class the model actually predicts, so the axis
# never shows a label for a class that has no bar.
class_ids = np.arange(len(probs))

fig, ax = plt.subplots(figsize=(6, 2))
ax.bar(class_ids, probs)
ax.set_ylim(0, 1)
ax.set_xticks(ticks=class_ids, labels=[mutation_type_names[int(c)] for c in class_ids])
ax.set_ylabel("Probability")
plt.show()

Compute accuracy over validation examples

In [None]:
class_predictions = compute_class_probability(valid_alignments).argmax(axis=-1)
print(f"Accuracy: {(class_predictions == valid_mutation_types).mean() * 100}%")