In [1]:
import torch
from torch import nn, optim
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score
from geoclip import LocationEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import precision_recall_curve, auc



# Mock data for testing

In [None]:
# We generate mock_data_final for showing the process of generating prototypes, 
# The mock_data_final is used for simulating the real-scenario data that we used for training the SEP model

In [2]:
def calculate_class_weights(y):
    # Count the frequency of each class
    class_counts = torch.bincount(y)
    # Calculate the total number of samples
    total_samples = len(y)
    # Number of classes is the length of class_counts
    num_classes = len(class_counts)
    # Calculate class weights
    class_weights = total_samples / (class_counts * num_classes)
    return class_weights

In [3]:
def calculate_AUCPR(y_test, y_score, n_classes):
    precision = dict()
    recall = dict()
    area_pr = dict()
    precision_mean=dict()
    recall_mean=dict()
    
    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                            y_score[:, i])
        area_pr[i]= auc(recall[i], precision[i],)
        precision_mean[i]=precision[i].mean()
        recall_mean[i]=recall[i].mean()
        
    mean_AUCPR = np.mean(list(area_pr.values()))
    mean_precision = np.mean(list(precision_mean.values()))
    mean_recall = np.mean(list(recall_mean.values()))
    mean_f1=2 * (mean_precision * mean_recall) / (mean_precision + mean_recall)
    return mean_precision, mean_recall, mean_f1, mean_AUCPR

In [4]:

# Define the VAE model
class SAVAE(nn.Module):
    def __init__(self):
        super(SAVAE, self).__init__()
        # Define the concatenation part
        self.fc11 = nn.Linear(12, 64)
        self.fc12 = nn.Linear(512, 64)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        # Define the second fully connected layer for mu
        self.fc21 = nn.Linear(128, 16)
        # Define the third fully connected layer for logvar
        self.fc22 = nn.Linear(128, 16)
        # Define the fourth fully connected layer for decoding
        self.fc3 = nn.Linear(16, 12)
        # Define the classifier layer
        self.classifier = nn.Linear(16, 3)
        self.LocEncoder = LocationEncoder()

    # Define the encoder part of VAE
    def encode(self, x_concat):
        h1 = self.LeakyReLU(x_concat)
        return self.fc21(h1), self.fc22(h1)

    # Define the reparameterization trick
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    # Define the decoder part of VAE
    def decode(self, z):
        h3=self.fc3(z)
        return torch.sigmoid(h3)

    def forward(self,x):
        x1=x[:,2:]
        x2=x[:,:2]
        
        loc_embed=self.LocEncoder(x2)
        x_concat=torch.cat((self.fc12(loc_embed), self.fc11(x1)), dim=1)
        mu, logvar = self.encode(x_concat.view(-1, x_concat.shape[1]))
        z = self.reparameterize(mu, logvar)
        # dimension of z is 16
        # return the reconstructed x, the classifier output, mu and logvar
        return self.decode(z), self.classifier(z), mu, logvar, z

# Define the loss function
def loss_function(recon_x, x1, mu, logvar, y, y_pred, class_weights):

    # Calculate the Mean Squared Error loss
    MSE = F.mse_loss(recon_x, x1, reduction='sum')
    # Calculate the KL Divergence
    KLD = beta_value*(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()))
    # Calculate the Cross Entropy loss
    CE = F.cross_entropy(y_pred, y, weight=class_weights)
    # The total loss is the sum of MSE, KLD, and CE
    return MSE, KLD, CE

In [5]:
# Define training function
def train(train_loader, model, device):
    model.train()
    MSE = 0
    KLD = 0
    CE = 0
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # This line is doing several things:
        # 1. It's passing the input data through the model, which is an instance of the VAE class.
        # 2. The model's forward method returns four values: the reconstructed batch (recon_batch), the predicted labels (y_pred), 
        #    and the mean (mu) and log variance (logvar) of the latent variables.
        # 3. These four values are being unpacked into the variables recon_batch, y_pred, mu, and logvar.
        recon_batch, y_pred, mu, logvar, z = model(data)
        MSE_loss, KLD_loss, CE_loss = loss_function(recon_batch, data[:, 2:], mu, logvar, labels, y_pred, class_weights)
        loss_train=MSE_loss+KLD_loss+CE_loss
        loss_train.backward()
        optimizer.step()
        
        MSE += MSE_loss.item()
        KLD += KLD_loss.item()
        CE += CE_loss.item()
        train_loss += loss_train.item()
        
    average_train_MSE = MSE / len(train_loader.dataset)
    average_train_KLD = KLD / len(train_loader.dataset)
    average_train_CE = CE / len(train_loader.dataset)
    average_train_loss = train_loss / len(train_loader.dataset)
    
    return average_train_MSE, average_train_KLD, average_train_CE, average_train_loss
        
# Define testing function
def test(test_loader, model, device):
    MSE = 0
    KLD = 0
    CE = 0
    test_loss = 0
    model.eval()
    all_labels = []  
    all_preds_prob = [] 
    all_preds=[]
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_loader):
            data = data.to(device)
            labels = labels.to(device)
            recon_batch, y_pred, mu, logvar, z = model(data)
            MSE_loss, KLD_loss, CE_loss = loss_function(recon_batch, data[:, 2:], mu, logvar, labels, y_pred, class_weights)
            loss_test=MSE_loss+KLD_loss+CE_loss
            
            MSE += MSE_loss.item()
            KLD += KLD_loss.item()
            CE += CE_loss.item()
            test_loss += loss_test.item()
            
            y_pred_prob=torch.nn.functional.softmax(y_pred, dim=1).cpu().numpy()
            all_labels.extend(labels.cpu().numpy())
            all_preds_prob.extend(y_pred_prob) 
            all_preds.extend(torch.argmax(torch.tensor(y_pred_prob), dim=1).cpu().numpy())
    
    arr_all_labels=np.array(all_labels)
    arr_all_labels_binarize=label_binarize(arr_all_labels, classes=[*range(3)])
    arr_preds_prob=np.stack(all_preds_prob)
    
    average_MSE_test = MSE / len(test_loader.dataset)
    average_KLD_test = KLD / len(test_loader.dataset)
    average_CE_test = CE / len(test_loader.dataset)
    average_test_loss = test_loss / len(test_loader.dataset)


    mean_precision, mean_recall, mean_f1, mean_AUCPR= calculate_AUCPR(arr_all_labels_binarize, arr_preds_prob, n_classes=3)
    mean_AUCROC = roc_auc_score(arr_all_labels, arr_preds_prob, multi_class='ovo', average='weighted')
    return average_MSE_test, average_KLD_test, average_CE_test, average_test_loss,  mean_precision, mean_recall, mean_f1, mean_AUCPR, mean_AUCROC, all_labels, all_preds_prob, all_preds

In [6]:
learning_rate=0.0001
epoch_num=50
batchsize_value=10
# beta_value to weight between the KLD and the MSE LOSS
beta_value=0.0001

In [7]:
mock_data_final=pd.read_csv('./mock_data/mock_data_all.csv')

In [8]:
# Load and preprocess the data
X = mock_data_final.drop('outcome', axis=1)
y = mock_data_final['outcome']
object = MinMaxScaler()
X_scaled=object.fit_transform(X.drop(['lat', 'lon'], axis=1))
X_scaled=pd.DataFrame(X_scaled, columns=X.drop(columns=['lat','lon']).columns)


X_coord=mock_data_final[['lat', 'lon']]
X_scaled_coord = pd.concat([X_coord, X_scaled], axis=1)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_scaled_coord, y, test_size=0.2, random_state=42, stratify=y)


In [9]:
X_train_tensor=torch.tensor(X_train.values).float()
X_test_tensor=torch.tensor(X_test.values).float()
y_train_tensor = torch.tensor(y_train.values)
y_test_tensor = torch.tensor(y_test.values)

In [10]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Create TensorDatasets
train_dataset = TensorDataset(X_train_tensor, y_train_tensor.long())  # .long() is used to convert the labels to integer
test_dataset = TensorDataset(X_test_tensor, y_test_tensor.long())
class_weights = calculate_class_weights(y_train_tensor.long())
class_weights=class_weights.to(device)
# Create DataLoaders
batch_size = batchsize_value
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


model = SAVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [11]:
train_MSE = []
train_KLD = []
train_CE = []
train_losses=[]

test_MSE = []
test_KLD = []
test_CE = []
test_losses=[]

min_val_loss = float('inf')
patience_counter = 0
patience_limit = 5

epochs = epoch_num
for epoch in range(epoch_num):
    average_train_MSE, average_train_KLD, average_train_CE, average_train_loss=train(train_loader, model, device=device)
    average_MSE_test, average_KLD_test, average_CE_test, average_test_loss, mean_precision, mean_recall, mean_f1, mean_AUCPR, mean_AUCROC, all_labels, all_preds_prob, all_preds=test(test_loader, model, device=device)
    print(f'train: epoch {epoch+1} MSE: {average_train_MSE:.4f}, KLD: {average_train_KLD:.4f}, CE: {average_train_CE:.4f}, loss: {average_train_loss:.4f}')
    print(f'test: epoch {epoch+1} MSE: {average_MSE_test:.4f}, KLD: {average_KLD_test:.4f}, CE: {average_CE_test:.4f}, loss: {average_test_loss:.4f}, AUCPR: {mean_AUCPR:.4f}, AUCROC: {mean_AUCROC:.4f}')
    
    train_MSE.append(average_train_MSE)
    train_KLD.append(average_train_KLD)
    train_CE.append(average_train_CE)
    test_MSE.append(average_MSE_test)
    test_KLD.append(average_KLD_test)
    test_CE.append(average_CE_test)
    
    train_losses.append(average_train_loss)
    test_losses.append(average_test_loss)
    
    if average_CE_test < min_val_loss:
        min_val_loss = average_CE_test
        
        # torch.save(model,'v2_best_savae_model.pth')
        # torch.save(model.state_dict(), 'v2_best_savae_model_dict.pth')
        
        patience_counter = 0
    else:
        patience_counter += 1
        
    if patience_counter >= patience_limit:
        print('Early stopping')
        break
    
    

train: epoch 1 MSE: 2.0998, KLD: 0.0000, CE: 0.1196, loss: 2.2194
test: epoch 1 MSE: 2.0189, KLD: 0.0000, CE: 0.1203, loss: 2.1392, AUCPR: 0.3324, AUCROC: 0.4673
train: epoch 2 MSE: 2.0149, KLD: 0.0003, CE: 0.1210, loss: 2.1361
test: epoch 2 MSE: 1.9176, KLD: 0.0008, CE: 0.1182, loss: 2.0366, AUCPR: 0.3188, AUCROC: 0.4761
train: epoch 3 MSE: 1.8966, KLD: 0.0014, CE: 0.1139, loss: 2.0119
test: epoch 3 MSE: 1.8035, KLD: 0.0019, CE: 0.1098, loss: 1.9152, AUCPR: 0.3681, AUCROC: 0.5130
train: epoch 4 MSE: 1.8145, KLD: 0.0019, CE: 0.1131, loss: 1.9295
test: epoch 4 MSE: 1.8047, KLD: 0.0023, CE: 0.1112, loss: 1.9182, AUCPR: 0.3476, AUCROC: 0.5086
train: epoch 5 MSE: 1.6985, KLD: 0.0025, CE: 0.1097, loss: 1.8106
test: epoch 5 MSE: 1.8665, KLD: 0.0030, CE: 0.1145, loss: 1.9840, AUCPR: 0.3340, AUCROC: 0.4702
train: epoch 6 MSE: 1.5056, KLD: 0.0037, CE: 0.1087, loss: 1.6180
test: epoch 6 MSE: 1.9365, KLD: 0.0037, CE: 0.1176, loss: 2.0578, AUCPR: 0.3132, AUCROC: 0.4714
train: epoch 7 MSE: 1.3241, 

In [12]:
mean_precision, mean_recall, mean_f1, mean_AUCPR, mean_AUCROC

(0.33253013367189826,
 0.49944958781592447,
 0.3992453996397909,
 0.33756555894420065,
 0.497961783008658)