In [1]:
import os
import pandas as pd
import numpy as np

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as TorchDataLoader
from torch.optim import Adam

import wandb

import matplotlib.pyplot as plt

In [None]:
# wandb reads the notebook name from the WANDB_NOTEBOOK_NAME *environment variable*,
# and only at init time. A plain Python variable assigned after wandb.init() has no effect,
# so export it first (the Python name is kept for any cell that refers to it).
WANDB_NOTEBOOK_NAME = "torchCNN.ipynb"
os.environ["WANDB_NOTEBOOK_NAME"] = WANDB_NOTEBOOK_NAME

wandb.init(project="torchCNN", entity="rasmusfreund")

# Hyperparameters tracked by the run; downstream cells read them from `config`.
config = wandb.config
config.learning_rate = 0.001
config.num_epochs = 10
config.batch_size = 32

In [2]:
class DataLoader:
    def __init__(self, csv_file_path: str | list[str], data_dir_path: str | list[str], one_hot_encode: bool = True, training_target: str = 'species', num_workers: int = 5):
        """
        Initializes the DataLoader with paths to the csv file and the directory containing the actual data files.

        Parameters:
        csv_file_path: File path (or list of paths) to the CSV file(s) containing the references and labels.
        data_dir_path: Directory path (or list of paths) that contains the actual data files.
        one_hot_encode: True / False statement on whether to one-hot encode species labels
        training_target: Sets label output to either 'species' or 'resistance' depending on target of the model
        num_workers: Integer defining how many CPU cores are available for the data generator
        """
        self.csv_file_path = csv_file_path
        self.num_workers = num_workers
        
        # Control whether to one-hot encode labels
        self.one_hot_encode = one_hot_encode

        # Check target of the model
        if training_target not in ['species', 'resistance']:
            raise ValueError('Argument "training_target" should either be "species" or "resistance".')
        else:
            self.training_target = training_target
        
        # Get data references
        self.data_references = self._load_data_references(csv_file_path, data_dir_path)
        
        # Adjust file paths in data_references to include the correct directory
        self.data_references['file_path'] = self.data_references.apply(self.assign_file_path, axis=1)
        
        # Mapping for one-hot encoding of label-vectorx
        # Case: species
        if self.training_target == 'species':
            self.label_mapping = {label: idx for idx, label in enumerate(self.data_references['species'].unique())}
        # Case: resistance
        else:
            self.antibiotic_columns = [col for col in self.data_references.columns[3:] if '_' not in col]
            self.label_mapping = {antibiotic: idx for idx, antibiotic in enumerate(self.antibiotic_columns)}
        
        self.num_labels = len(self.label_mapping)

    def __len__(self):
        return len(self.data_references)

    
    def __getitem__(self, idx):
        data, label = self.load_data(idx)
        return torch.tensor(data, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

    
    def _load_data_references(self, csv_file_path, data_dir_path):
        
        data_frames = []
        if isinstance(csv_file_path, list):
            for path, dir_path in zip(csv_file_path, data_dir_path):
                df = pd.read_csv(path, low_memory=False)
                df['source_csv'] = path
                df['data_dir'] = dir_path
                data_frames.append(df)
        else:
            df = pd.read_csv(csv_file_path, low_memory=False)
            df['source_csv'] = csv_file_path
            df['data_dir'] = data_dir_path
            data_frames.append(df)
        
        return pd.concat(data_frames, ignore_index=True)

    
    def assign_file_path(self, row):
        """
        Assigns the correct file path based on the code and directory paths.
        """
        code = row['code']
        data_dir = row['data_dir']
        file_path = f'{data_dir}/{code}.txt'
        if not os.path.exists(file_path):
            raise FileNotFoundError(f'File {code}.txt not found in provided directories.')
        return file_path

    
    def load_data(self, index: int):
        """
        Loads a single data file based on the index provided and returns the contents and label.
        If target of the model is resistances, (R)esistant and (I)ntermediate will both be marked by 1, indicating resistance.

        Parameters:
        index: The index of the data file to load, as referenced in the CSV file.

        Returns:
        data, label pair         
        """

        file_path = self.data_references.iloc[index]['file_path']
        #print(f'Loading data from {file_path}')
        data = self.convert_file_to_floats(file_path)
        if not data:
            print(f'No loaded data from: {file_path}')
        label = None
        
        if self.training_target == 'resistance':
            # Initialise k-hot vector for resistance data
            resistance_vector = [0] * self.num_labels
            for col in self.antibiotic_columns:
                # Get resistance label for each column
                resistance = self.data_references.iloc[index][col]
                if resistance in ['R', 'I']: 
                    col_idx = self.label_mapping[col]
                    resistance_vector[col_idx] = 1
            label = resistance_vector
        else:        
            species = self.data_references.iloc[index]['species']
            label_index = self.label_mapping[species]
            label = np.zeros(self.num_labels)
            label[label_index] = 1 if self.one_hot_encode else label_index
            
        return data, label


    def convert_file_to_floats(self, file_path: str):
        
        data = [] # Container for converted data
        try:
            with open(file_path, 'r') as file:
                next(file) # Skip the header
                for line in file:
                    string_values = line.strip().split(',')[0]
                    data.append(float(string_values))
        except Exception as e:
            print(f'Error processing file {file_path}: {e}')
        return data

    
    def one_hot_encode_label(self, label: str):
        
        one_hot_vector = np.zeros(self.num_labels)
        label_index = self.label_mapping[label]
        one_hot_vector[label_index] = 1
        
        return one_hot_vector


    def get_resistances_from_vector(self, k_hot_vector):

        index_to_antibiotic = {v: k for k, v in self.label_mapping.items()}
        active_indeces = [i for i, val in enumerate(k_hot_vector) if val == 1]
        active_resistances = [index_to_antibiotic[i] for i in active_indeces]

        return active_resistances
    

    def data_generator(self):
        """
        Generator function that yields one data-label pair at a time; parallelised.
        """
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            future_to_idx = {executor.submit(self.load_data, idx): idx for idx in range(len(self.data_references))}
            for future in as_completed(future_to_idx):
                yield future.result()
    


In [None]:
class TorchData(Dataset):
    """In-memory dataset that pairs feature rows with label rows, both stored as float tensors."""

    def __init__(self, X, y):
        # Convert once up front so indexing later is a cheap tensor lookup.
        features = torch.FloatTensor(X)
        targets = torch.FloatTensor(y)
        self.X, self.y = features, targets

    def __len__(self):
        """Number of samples, taken from the first dimension of the feature tensor."""
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        """Returns the (features, label) pair stored at position idx."""
        sample = (self.X[idx], self.y[idx])
        return sample

In [None]:
class TorchRidge(nn.Module):
    """
    Small 1-D convolutional multi-label classifier.

    A same-length convolution is followed by two passes through one strided
    convolution (each pass divides the length by 16), then a linear layer and a sigmoid.
    """

    def __init__(self, num_labels, input_length=35840):
        """
        Parameters:
        num_labels: Number of output units (one sigmoid probability per label).
        input_length: Number of values per input sample. The default reproduces the
                      previously hard-coded 2240 input features of the linear layer.

        Raises:
        ValueError: If input_length is too short to survive both reductions (< 256).
        """
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.reduce_conv1 = nn.Conv1d(16, 16, kernel_size=16, stride=16)

        # conv1 keeps the length (kernel 3, padding 1); each reduce pass (kernel 16,
        # stride 16, no padding) floors the length by a factor of 16.
        reduced_length = (input_length // 16) // 16
        if reduced_length < 1:
            raise ValueError(f'input_length must be at least 256, got {input_length}.')
        self.fc = nn.Linear(16 * reduced_length, num_labels)

    def forward(self, x):
        """Maps a (batch, length) tensor to (batch, num_labels) probabilities in [0, 1]."""
        x = x.unsqueeze(1) # Add the channel dimension expected by Conv1d
        x = F.relu(self.conv1(x))
        # The same module is applied twice, so both reductions share one set of weights
        x = self.reduce_conv1(x)
        x = self.reduce_conv1(x)
        x = x.view(x.size(0), -1) # Flatten tensor
        x = self.fc(x)
        return torch.sigmoid(x)

In [None]:
def multilabel_accuracy(outputs, labels):
    """
    Fraction of individual label decisions that match the targets.

    Outputs above 0.5 count as positive predictions. The denominator is
    rows * columns of the prediction tensor, so a 2-D input is expected.
    Returns a 0-dimensional tensor.
    """
    predicted_positive = outputs.gt(0.5)
    hits = predicted_positive.eq(labels).float()
    n_rows = predicted_positive.shape[0]
    n_cols = predicted_positive.size(1)
    return hits.sum() / (n_rows * n_cols)

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """
    Runs one optimisation pass over `train_loader`.

    Parameters:
    model: Module to train; switched to training mode here.
    train_loader: Iterable of (data, targets) tensor batches.
    criterion: Loss function called as criterion(outputs, targets).
    optimizer: Optimizer stepping the model parameters.
    device: Device the batches are moved to.

    Returns:
    (avg_loss, avg_accuracy) as Python floats, averaged per sample.
    """
    model.train()
    total_loss = 0.0
    total_accuracy = 0.0
    total_samples = 0

    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size: the last batch is usually smaller, so a plain
        # mean of per-batch means would over-weight its samples.
        # Assumes the criterion returns a per-batch mean (the default reduction) -- TODO confirm.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        # float() turns the 0-d accuracy tensor into a plain number, matching the loss
        total_accuracy += float(multilabel_accuracy(outputs, targets)) * batch_size
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_accuracy = total_accuracy / total_samples
    return avg_loss, avg_accuracy

def validate(model, test_loader, criterion, device):
    """
    Evaluates `model` on `test_loader` without tracking gradients.

    Parameters:
    model: Module to evaluate; switched to evaluation mode here.
    test_loader: Iterable of (data, targets) tensor batches.
    criterion: Loss function called as criterion(outputs, targets).
    device: Device the batches are moved to.

    Returns:
    (avg_loss, avg_accuracy) as Python floats, averaged per sample.
    """
    model.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    total_samples = 0

    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)

            # Weight each batch by its size so a smaller final batch does not skew the averages.
            # Assumes the criterion returns a per-batch mean (the default reduction) -- TODO confirm.
            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size
            # float() turns the 0-d accuracy tensor into a plain number, matching the loss
            total_accuracy += float(multilabel_accuracy(outputs, targets)) * batch_size
            total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_accuracy = total_accuracy / total_samples
    return avg_loss, avg_accuracy

In [None]:
# Reference CSV and data directory for the 2015 train/validation set, with resistance labels.
loader_2015_args = dict(
    csv_file_path='/faststorage/project/amr_driams/rasmus/data/DRIAMS-A/id/2015_clean_train_val.csv',
    data_dir_path='/faststorage/project/amr_driams/data/DRIAMS-A/preprocessed_raw/2015',
    training_target='resistance',
    num_workers=4,
)
data_loader = DataLoader(**loader_2015_args)
total_data_points = len(data_loader)

# Collect every (data, label) pair in memory; pairs arrive in completion order.
X, y = [], []
loading_progress = tqdm(data_loader.data_generator(), total=total_data_points, desc="Loading Data")
for data, label in loading_progress:
    X.append(data)
    y.append(label)

# Fixed-seed 80/20 split into training and held-out sets.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)


In [3]:
# Reference CSV and data directory for the 2016 train/validation set, with resistance labels.
loader_2016_args = dict(
    csv_file_path='/faststorage/project/amr_driams/rasmus/data/DRIAMS-A/id/2016_clean_train_val.csv',
    data_dir_path='/faststorage/project/amr_driams/data/DRIAMS-A/preprocessed_raw/2016',
    training_target='resistance',
    num_workers=5,
)
data_loader = DataLoader(**loader_2016_args)
total_data_points = len(data_loader)

# NOTE(review): this rebinds `data_loader`, `X` and `y`. `X` stays empty and `y` now holds
# only the 2016 labels, so later cells reading `y` see this data -- confirm that is intended.
X, y = [], []
label_progress = tqdm(data_loader.data_generator(), total=total_data_points, desc="Loading Data")
for _, label in label_progress:
    # Only the labels are kept; the loaded data values are discarded.
    y.append(label)


FileNotFoundError: File 9681cb39-ace9-47f6-a93f-3d3a5e0c9b2d.txt not found in provided directories.

In [None]:
# Baseline: accuracy obtained by predicting 0 for every label of every sample.
num_classes = len(y[0])
naive_guess = [0] * num_classes

accuracies = []
for ground_truth in y:
    # Count the positions where the all-zero guess agrees with the true label vector
    matches = sum(gt == ng for gt, ng in zip(ground_truth, naive_guess))
    accuracy = matches / num_classes
    accuracies.append(accuracy)

mean_acc = sum(accuracies) / len(accuracies)
print("Mean accuracy for naive guess:", mean_acc)

In [None]:
# Wrap the split lists as tensor datasets.
train_dataset = TorchData(X_train, y_train)
test_dataset = TorchData(X_test, y_test)

# Use the batch size recorded in the wandb config so the logged hyperparameter is the one applied.
# Only the training loader is shuffled; evaluation order does not need to be randomised.
train_loader = TorchDataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=1)
test_loader = TorchDataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=1)

In [None]:
# Size the output layer from the labels the model is trained on. The global `y` is
# rebound by another cell, so it may not describe the training labels.
model = TorchRidge(num_labels=len(y_train[0]))
device = 'cpu'
model.to(device) # Keep parameters on the same device the batches are moved to
optimizer = Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.BCELoss()

# File the model weights are written to before being handed to wandb
MODEL_PATH = 'torchCNN_state_dict.pt'

# Training loop
train_loss_hist, train_accuracy_hist = [], []
test_loss_hist, test_accuracy_hist = [], []

for epoch in range(config.num_epochs):
    print(f'Initialising epoch {epoch+1}/{config.num_epochs}')

    train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_accuracy = validate(model, test_loader, criterion, device)

    print(f"Epoch: {epoch+1}\nLoss: {train_loss:.4f}\nAccuracy: {train_accuracy:.4f}\nTest Loss: {test_loss:.4f}\nTest Accuracy: {test_accuracy:.4f}")

    train_loss_hist.append(train_loss)
    train_accuracy_hist.append(train_accuracy)
    test_loss_hist.append(test_loss)
    test_accuracy_hist.append(test_accuracy)

    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy
    })

    # wandb.save() uploads an existing file, so write the weights first
    torch.save(model.state_dict(), MODEL_PATH)
    wandb.save(MODEL_PATH)

wandb.finish()