In [1]:
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

In [None]:
class ResistNN(pl.LightningModule):
    """Multi-task network: species classification plus per-antibiotic resistance.

    A small 1-D convolutional stack turns each input spectrum into a
    16-dimensional feature vector. One linear head predicts the species, and
    one independent sigmoid head per antibiotic predicts resistance from the
    features concatenated with the species probabilities.

    Args:
        num_species: Number of species classes produced by the species head.
        num_antibiotics: Number of resistance heads (one output each).
    """

    def __init__(self, num_species, num_antibiotics):
        super().__init__()
        # Base feature extraction
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=0),
            # Collapse the length axis so both heads receive one 16-d vector per
            # sample, which is what their nn.Linear(16, ...) layers expect.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
        )
        # Species prediction head. It emits raw logits: softmax is applied in
        # forward(), while the loss consumes the logits directly.
        self.species_head = nn.Linear(16, num_species)

        # Resistance prediction heads, one per antibiotic
        self.resistance_head = nn.ModuleList([
            nn.Sequential(
                nn.Linear(16 + num_species, 1),
                nn.Sigmoid(),
            ) for _ in range(num_antibiotics)
        ])

    def _forward_with_logits(self, x):
        """Run the network and also expose the raw species logits.

        Returns:
            Tuple ``(species_logits, species_pred, resistance_preds)``.
        """
        # Conv1d needs an explicit channel axis; accept (batch, length) too.
        if x.dim() == 2:
            x = x.unsqueeze(1)

        features = self.feature_extractor(x)
        species_logits = self.species_head(features)
        species_pred = torch.softmax(species_logits, dim=1)

        # Every resistance head reads the same input, so build it once.
        head_input = torch.cat((features, species_pred), dim=1)
        resistance_preds = torch.cat(
            [head(head_input) for head in self.resistance_head], dim=1
        )
        return species_logits, species_pred, resistance_preds

    def forward(self, x):
        """Predict species probabilities and resistance probabilities.

        Args:
            x: Spectra shaped ``(batch, 1, length)`` or ``(batch, length)``.

        Returns:
            Tuple ``(species_pred, resistance_preds)`` where ``species_pred`` is
            ``(batch, num_species)`` softmax probabilities and
            ``resistance_preds`` is ``(batch, num_antibiotics)`` sigmoid outputs.
        """
        _, species_pred, resistance_preds = self._forward_with_logits(x)
        return species_pred, resistance_preds

    def configure_optimizers(self):
        """Return the Adam optimiser (lr=1e-3) over all model parameters."""
        optimiser = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimiser

    # Lightning only looks up ``configure_optimizers``; the original spelling is
    # kept as an alias so existing callers of the old name keep working.
    configure_optimisers = configure_optimizers

    def training_step(self, batch, batch_idx):
        """Compute and log the summed species + resistance loss for one batch.

        Args:
            batch: Tuple ``(spectra, species_labels, resistance_labels)``.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar tensor: cross-entropy species loss plus binary cross-entropy
            resistance loss.
        """
        # Unpack batch
        spectra, species_labels, resistance_labels = batch

        # Fwd pass
        species_logits, _, resistance_preds = self._forward_with_logits(spectra)

        # Loss. Cross-entropy applies log-softmax itself, so it must be given
        # logits rather than the softmax probabilities returned by forward().
        species_loss = nn.functional.cross_entropy(species_logits, species_labels)
        # assumes resistance_labels is (batch, num_antibiotics) with values in
        # [0, 1] -- TODO confirm against the data module. BCE needs targets in
        # the same floating dtype as the predictions.
        resistance_loss = nn.functional.binary_cross_entropy(
            resistance_preds, resistance_labels.to(resistance_preds.dtype)
        )
        total_loss = species_loss + resistance_loss

        # Logging
        self.log('train_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return total_loss



In [None]:
class ResistDataModule(pl.LightningDataModule):
    """Placeholder data module: neither dataloader is implemented yet."""

    def __init__(self):
        super(ResistDataModule, self).__init__()

    def train_dataloader(self):
        """Stub for the training loader; currently returns ``None``."""
        return None

    def val_dataloader(self):
        """Stub for the validation loader; currently returns ``None``."""
        return None
        

In [None]:
# TODO: set both values from the dataset before running this cell; they were
# left unfilled in the original notebook.
NUM_SPECIES = None
NUM_ANTIBIOTICS = None

# Fail early with a clear message, before a W&B run is created.
if NUM_SPECIES is None or NUM_ANTIBIOTICS is None:
    raise ValueError("Set NUM_SPECIES and NUM_ANTIBIOTICS before training.")

wandb_logger = WandbLogger(project="amr_driams")

model = ResistNN(num_species=NUM_SPECIES, num_antibiotics=NUM_ANTIBIOTICS)

data_module = ResistDataModule()

# `max_epochs` is the Trainer keyword; the misspelt original raised TypeError.
trainer = pl.Trainer(logger=wandb_logger, max_epochs=5)
trainer.fit(model, data_module)