# Neural Networks - Final Project
## Finding MIMO

### Sarah Baalbaki, Tanxin Qiao, Jackie Vo

In [None]:
# Mount Google Drive at /content/drive (Colab only); data_path below points into it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the third-party packages used by this notebook (shell commands).
# NOTE(review): torchmetrics and lightning are not version-pinned, unlike the first two packages.
! pip install tape-proteins==0.5
! pip install biopython==1.80
! pip install torchmetrics
! pip install lightning

In [None]:
# Folder that contains mimo_train.lmdb, mimo_val.lmdb and mimo_test.lmdb.
# The commented-out paths are presumably other Drive layouts - keep only the one that applies.
# data_path = '/content/drive/MyDrive/NN/Project/MIMO_data'
# data_path = '/content/drive/MyDrive/CMU/Project/MIMO_data'
data_path= "/content/drive/MyDrive/MIMO_data"

In [None]:
import torch
import torchmetrics
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import AutoTokenizer, EsmForMaskedLM
from transformers import AutoTokenizer, EsmModel

## Code to load data 

In [None]:
import pandas as pd
import numpy as np
from tape.datasets import LMDBDataset

# One-letter amino-acid codes in alphabetical order, with "-" as the last symbol.
GFP_AMINO_ACID_VOCABULARY = list("ACDEFGHIKLMNPQRSTVWY-")
# Maps every vocabulary symbol to its position in the list above.
GFP_ALPHABET = dict(zip(GFP_AMINO_ACID_VOCABULARY, range(len(GFP_AMINO_ACID_VOCABULARY))))

def gfp_dataset_to_df(in_name):
    """Read the LMDB dataset at ``in_name`` into a pandas DataFrame.

    Every record of the dataset becomes one row. Each entry of the
    ``log_fluorescence`` column is replaced by its first element.
    """
    records = list(LMDBDataset(in_name))
    frame = pd.DataFrame(records)
    # Each stored value is indexable; keep only element 0 of it.
    first_element = frame["log_fluorescence"].apply(lambda value: value[0])
    frame["log_fluorescence"] = first_element
    return frame

def get_gfp_dfs(path):
    """Return the ``(train, val, test)`` GFP dataframes stored under ``path``."""
    lmdb_files = (
        f"{path}/mimo_train.lmdb",
        f"{path}/mimo_val.lmdb",
        f"{path}/mimo_test.lmdb",
    )
    # Load the three splits in train, val, test order.
    train_df, val_df, test_df = map(gfp_dataset_to_df, lmdb_files)
    return train_df, val_df, test_df

def get_gfp_dfs_shuffled(path):
    """Pool the three GFP splits, shuffle them and re-split 80/10/10.

    The train, val and test files under ``path`` are concatenated, shuffled
    with a fixed seed (42) and split again into 80% train, 10% val and
    10% test, so the result is reproducible.

    Args:
        path (str): folder containing the ``mimo_*.lmdb`` files.

    Returns:
        tuple: ``(train_df, val_df, test_df)`` dataframes.
    """
    # Imported here because no visible cell imports train_test_split.
    from sklearn.model_selection import train_test_split

    # Reuse the plain loader instead of repeating the three file reads.
    all_data = pd.concat(list(get_gfp_dfs(path)), ignore_index=True)
    all_data_shuffled = all_data.sample(frac=1, random_state=42).reset_index(drop=True)
    train_df, val_test_df = train_test_split(all_data_shuffled, test_size=0.2, random_state=42)
    # Halve the held-out 20% into validation and test.
    val_df, test_df = train_test_split(val_test_df, test_size=0.5, random_state=42)
    return train_df, val_df, test_df

## Explore the Data 

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np

# Load the three splits and print how many rows each one has.
train_df, val_df, test_df = get_gfp_dfs(data_path)
print('Training set size:', len(train_df))
print('Validation set size:', len(val_df))
print('Test set size:',len(test_df))
# Last expression of the cell: the notebook displays the first training rows.
train_df.head()

## Load the Data - DataLoaders

In [None]:
from torch.utils.data import Dataset
from lightning import LightningDataModule
from torch.utils.data import DataLoader

class GFPDatamodule(LightningDataModule):
    """Lightning data module built on the GFP train/val/test dataframes.

    The train and validation loaders yield ``(primary, log_fluorescence)``
    pairs. The test and predict loaders yield the ``primary`` sequences of
    the test dataframe only. Only the training loader shuffles.
    """

    def __init__(self, root_path, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.root = root_path
        self.train_df, self.val_df, self.test_df = get_gfp_dfs(root_path)

    def _paired_loader(self, frame, shuffle):
        """Build a loader over (sequence, signal) pairs taken from ``frame``."""
        pairs = list(zip(frame['primary'].to_list(), frame['log_fluorescence'].to_list()))
        return DataLoader(pairs, batch_size=self.batch_size, shuffle=shuffle)

    def _sequence_loader(self, frame):
        """Build an unshuffled loader over the sequences of ``frame``."""
        only_sequences = list(frame['primary'].to_list())
        return DataLoader(only_sequences, batch_size=self.batch_size, shuffle=False)

    def train_dataloader(self):
        return self._paired_loader(self.train_df, shuffle=True)

    def val_dataloader(self):
        return self._paired_loader(self.val_df, shuffle=False)

    def test_dataloader(self):
        return self._sequence_loader(self.test_df)

    def predict_dataloader(self):
        return self._sequence_loader(self.test_df)

In [None]:
# Build the data module with a batch size of 32.
datamodule = GFPDatamodule(data_path, 32)

## 1- Shuffled Dataset 

### GFPRegressor for Shuffled Dataset

In [None]:
from transformers import AutoTokenizer, EsmModel
import torch
import torch.nn as nn
import torchmetrics
from torch.optim import Adam
from lightning import LightningModule

class GFPRegressor(LightningModule):
    """Linear regression head on top of frozen ESM-2 sequence embeddings.

    Sequences are tokenized and encoded by ``facebook/esm2_t6_8M_UR50D``
    without gradient tracking, mean-pooled over their real (non-padding)
    tokens, and mapped to one scalar by ``self.model``.

    Args:
        lr (float): Adam learning rate. Defaults to the original 1e-1.
    """

    def __init__(self, lr=1e-1):
        super().__init__()
        self.lr = lr
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.encoder = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
        dmodel = self.encoder.config.hidden_size
        self.model = nn.Linear(dmodel, 1)

        self.criterion = nn.MSELoss()
        # squared=False turns the metric into a real RMSE; the default reports MSE.
        self.val_rmse = torchmetrics.MeanSquaredError(squared=False)

        self.train_losses = []
        self.val_losses = []

        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions for the sequences ``x``."""
        inputs = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt")
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.encoder(**inputs)
        hidden = outputs.last_hidden_state
        # Average over real tokens only, so padding added for shorter sequences
        # in a batch cannot change a sequence's embedding.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        pooled_output = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        logits = self.model(pooled_output)
        return logits

    def _predict_and_score(self, batch):
        """Run the model on a ``(sequences, targets)`` batch and compute the MSE loss."""
        x, y = batch
        # Drop only the trailing feature axis; a bare squeeze() would also
        # remove the batch axis when the batch holds a single sample.
        preds = self.forward(x).squeeze(-1)
        target = y.float()
        return preds, target, self.criterion(preds, target)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._predict_and_score(batch)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, target, loss = self._predict_and_score(batch)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        # Update the metric and log it so the epoch-level RMSE is actually reported.
        self.val_rmse(preds, target)
        self.log('val_rmse', self.val_rmse, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        return optimizer

### Train the model 

In [None]:
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Recreate the data module and a fresh regressor, then fit for up to 50 epochs.
datamodule = GFPDatamodule(data_path, batch_size=32)
model = GFPRegressor()

# NOTE(review): ModelCheckpoint is imported above but no callback is passed to the Trainer.
trainer = Trainer(max_epochs=50)
trainer.fit(model=model, datamodule=datamodule)

### Obtain metrics and plots

In [None]:
# Display the metrics the trainer sent to its logger.
trainer.logged_metrics

In [None]:
# Keep the trainer's callback metrics and list the available metric names.
logged_data = trainer.callback_metrics
print(logged_data.keys())

In [None]:
# Print the logged training and validation losses as plain numbers.
train_loss = logged_data['train_loss']
print(train_loss.item())

val_loss = logged_data['val_loss']
print(val_loss.item())

In [None]:
# Enable the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
# NOTE(review): the log directory is fixed to version_0; later runs presumably write to other version folders - confirm.
%tensorboard --logdir /content/lightning_logs/version_0

### Evaluate the model 

In [None]:
# Switch to evaluation mode and run prediction through the data module
# (GFPDatamodule.predict_dataloader serves the test sequences).
model.eval()
predictions = trainer.predict(model, datamodule)

In [None]:
# Flatten the per-batch prediction tensors into one list of Python numbers.
preds = [value.item() for batch_preds in predictions for value in batch_preds]

print(len(preds))
print(preds)

In [None]:
from sklearn.metrics import mean_squared_error

# true_vals was never defined in a visible cell. The predictions come from
# predict_dataloader, which serves datamodule.test_df in order without
# shuffling, so its targets line up with preds one-to-one.
true_vals = datamodule.test_df['log_fluorescence'].to_list()

# Compute the MSE once and derive the RMSE from it.
mse = mean_squared_error(true_vals, preds)
rmse = np.sqrt(mse)
print("Root Mean Squared Error:", rmse)

print("Mean Squared Error:", mse)

## 2- Unbalanced Dataset

### Helper functions for saving model checkpoints

In [None]:
def save_chkpt(ckpt_folder, model, epoch, ver):
    """Save the model's state dict as ``{ckpt_folder}/model_v{ver}_epoch{epoch}``.

    Args:
        ckpt_folder (str): folder to write the checkpoint into; it is created
            if it does not exist yet
        model (nn.Module): the model whose ``state_dict()`` is saved
        epoch (int): the current epoch, used in the file name
        ver (int): the run version, used in the file name
    """
    import os

    # torch.save does not create missing parent folders.
    os.makedirs(ckpt_folder, exist_ok=True)

    path = f'{ckpt_folder}/model_v{ver}_epoch{epoch}'
    torch.save(model.state_dict(), path)

In [None]:
def GetVersion(path):
    """Return the next unused checkpoint version number for folder ``path``.

    Checkpoints are expected to be named ``model_v<ver>_epoch<epoch>``, the
    pattern written by ``save_chkpt``. The result is one more than the
    highest ``<ver>`` found.

    Args:
        path (str | None): checkpoint folder.

    Returns:
        int: ``1`` when ``path`` is ``None``, missing, or holds no matching
        checkpoint; otherwise the highest existing version plus one.
    """
    # Imported here because no visible cell imports os.
    import os
    import re

    if path is None or not os.path.isdir(path):
        return 1

    # Anchor on the whole naming scheme so unrelated files are ignored
    # instead of breaking the integer parsing.
    pattern = re.compile(r'model_v(\d+)(?:_|$)')
    found = (pattern.match(name) for name in os.listdir(path))
    all_vers = [int(hit.group(1)) for hit in found if hit]
    return max(all_vers, default=0) + 1

### Modified data loading

In [None]:
class GFPDataset(Dataset):
    """Map-style dataset over one split of the GFP data.

    Args:
        root_path (str): folder containing the ``mimo_*.lmdb`` files.
        part (str): which split to load: ``'train'``, ``'val'`` or ``'test'``.
        input_name (str): dataframe column returned as the input.
        label_name (str): dataframe column returned as the label.

    Raises:
        ValueError: if ``part`` is not one of the three known splits.
    """

    def __init__(self, root_path, part, input_name, label_name):
        super().__init__()
        if part not in ('train', 'val', 'test'):
            raise ValueError(f"part must be 'train', 'val' or 'test', got {part!r}")
        self.root = root_path
        # get_gfp_dfs(path) has no ``part`` argument and returns all three
        # splits, so load the single requested file directly.
        df = gfp_dataset_to_df(f"{root_path}/mimo_{part}.lmdb")
        self.input = df[input_name].to_numpy()
        self.label = df[label_name].to_numpy()

    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        return self.input[index], self.label[index]

### SimpleGFPRegressor

In [None]:
class SimpleGFPModel(nn.Module):
    """Two Conv1d/BatchNorm/ReLU stages followed by a two-layer dense head.

    ``forward`` takes a ``(batch, 1, in_dim)`` tensor and returns a flat
    tensor with one value per sample.
    """

    def __init__(self, in_dim):
        super().__init__()
        kernel = 3
        self.conv1 = nn.Conv1d(1, 512, kernel_size=kernel)
        self.bn1 = nn.BatchNorm1d(512)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv1d(512, 256, kernel_size=kernel)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu2 = nn.ReLU()
        # Each of the two convolutions shortens the signal by kernel - 1 positions.
        flat_len = in_dim - 2 * (kernel - 1)
        self.dense1 = nn.Linear(256 * flat_len, 1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.relu3 = nn.ReLU()
        self.fc = nn.Linear(1024, 1)

    def forward(self, x):
        n_samples = x.shape[0]
        features = self.relu1(self.bn1(self.conv1(x)))
        features = self.relu2(self.bn2(self.conv2(features)))
        # Flatten channels and positions before the dense head.
        flat = features.reshape(n_samples, -1)
        hidden = self.relu3(self.bn3(self.dense1(flat)))
        return self.fc(hidden).view(-1)

class GFPRegressor(nn.Module):
    """ESM-2 embedder followed by ``SimpleGFPModel`` as the regression head.

    Args:
        device (torch.device): device the tokenized inputs are moved to; the
            module itself must be moved there by the caller (``model.to``).
    """

    def __init__(self, device):
        super(GFPRegressor, self).__init__()
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.embedder = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.model = SimpleGFPModel(in_dim=self.embedder.config.hidden_size)
        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Return one prediction per sequence in ``x``."""
        lengths = torch.tensor([len(i) for i in x]).to(self.device)
        ids = self.tokenizer(x, padding="longest", return_tensors="pt")
        # The tokenizer already returns tensors; re-wrapping them in
        # torch.tensor() only copies them and raises a UserWarning.
        input_ids = ids['input_ids'].to(self.device)
        attention_mask = ids['attention_mask'].to(self.device)
        with torch.no_grad():
            embeddings = self.embedder(input_ids=input_ids,
                                       attention_mask=attention_mask).last_hidden_state
        # Zero out padded positions so that shorter sequences in a batch are
        # not polluted by padding states. Batches without padding are
        # unaffected, and the denominator is kept as the raw sequence length
        # so values stay compatible with the original pooling.
        embeddings = embeddings * attention_mask.unsqueeze(-1).to(embeddings.dtype)
        embeddings = embeddings.sum(dim=1) / lengths.view(-1, 1)
        # SimpleGFPModel's first Conv1d expects a single input channel.
        embeddings = torch.unsqueeze(embeddings, 1)
        preds = self.model(embeddings)
        return preds

### Custom train and test functions

In [None]:
def train(model, train_dataloader, val_dataloader, optimizer, num_epochs, device, save_path = None):
    """Train ``model`` and validate it after every epoch.

    Args:
        model (nn.Module): model exposing a ``criterion`` attribute used as the loss.
        train_dataloader (DataLoader): yields ``(inputs, labels)`` training batches.
        val_dataloader (DataLoader): yields ``(inputs, labels)`` validation batches.
        optimizer (torch.optim.Optimizer): optimizer stepped once per training batch.
        num_epochs (int): number of epochs to run.
        device (torch.device): device the model and labels are moved to.
        save_path (str | None): checkpoint folder; when ``None`` nothing is saved.

    Returns:
        tuple[list[float], list[float]]: the per-epoch training and validation
        losses, each a sample-weighted mean over its dataset.
    """
    # tqdm is used below but is not imported by any visible cell.
    from tqdm import tqdm

    model.to(device)

    all_tloss = []
    all_vloss = []

    # Only resolve a checkpoint version when checkpoints will be written.
    ver = GetVersion(save_path) if save_path is not None else None

    n_train = len(train_dataloader.dataset)
    n_val = len(val_dataloader.dataset)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        # The bar is advanced once per training batch and once per validation batch.
        progress_bar = tqdm(total=len(train_dataloader) + len(val_dataloader), dynamic_ncols=True, leave=True, position=0, desc=f'Epoch {epoch+1}')

        for inputs, labels in train_dataloader:
            labels = torch.as_tensor(labels, dtype=torch.float32, device=device)

            optimizer.zero_grad()
            outputs = model(inputs)
            curr_loss = model.criterion(outputs, labels)
            curr_loss.backward()

            optimizer.step()

            # .item() keeps the running total a plain float, so no autograd
            # history is carried along and the returned lists can be plotted.
            train_loss += curr_loss.item() * len(inputs) / n_train

            progress_bar.set_postfix(train_loss="{:.04f}".format(train_loss))
            progress_bar.update()

        all_tloss.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for inputs, labels in val_dataloader:
                labels = torch.as_tensor(labels, dtype=torch.float32, device=device)

                outputs = model(inputs)
                loss = model.criterion(outputs, labels)
                val_loss += loss.item() * len(inputs) / n_val

                progress_bar.set_postfix(train_loss="{:.04f}".format(train_loss), val_loss="{:.04f}".format(val_loss))
                progress_bar.update()

        all_vloss.append(val_loss)
        progress_bar.close()

        if save_path is not None:
            save_chkpt(save_path, model, epoch, ver)

    print('Training finished!')

    return all_tloss, all_vloss

In [None]:
def test(model, test_dataloader, optimizer, device):
    """Run ``model`` over ``test_dataloader`` and collect its predictions.

    Args:
        model (nn.Module): model exposing a ``criterion`` attribute used as the loss.
        test_dataloader (DataLoader): yields ``(inputs, labels)`` batches.
        optimizer: unused; kept so existing calls keep working.
        device (torch.device): device the model and labels are moved to.

    Returns:
        list[torch.Tensor]: one scalar tensor per test sample, in loader order.
    """
    # tqdm is used below but is not imported by any visible cell.
    from tqdm import tqdm

    # The model may not have been through train(), which is the only other
    # place that moves it to the device.
    model.to(device)
    model.eval()

    preds = []

    testing_bar = tqdm(total=int(len(test_dataloader)), dynamic_ncols=True, leave=True, position=0, desc=f'Testing')
    test_loss = 0.0
    n_test = len(test_dataloader.dataset)

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            labels = torch.as_tensor(labels, dtype=torch.float32, device=device)

            outputs = model(inputs)
            preds.extend(outputs)
            loss = model.criterion(outputs, labels)
            # Sample-weighted running mean, kept as a plain float for display.
            test_loss += loss.item() * len(inputs) / n_test

            testing_bar.set_postfix(test_loss="{:.04f}".format(test_loss))
            testing_bar.update()
    testing_bar.close()

    print('Finish testing!')

    return preds

### Training

In [None]:
# Hyper-parameters for the custom training loop: batch size 64, learning rate 1e-4.
configs = dict(batch_size=64, lr=1e-4)

In [None]:
# One GFPDataset per split, each pairing the 'primary' column with the 'log_fluorescence' column.
train_dataset = GFPDataset(root_path=data_path, part='train', input_name='primary', label_name='log_fluorescence')
val_dataset= GFPDataset(root_path=data_path, part='val', input_name='primary', label_name='log_fluorescence')
test_dataset = GFPDataset(root_path=data_path, part='test', input_name='primary', label_name='log_fluorescence')

In [None]:
# Only the training loader shuffles; validation and test keep dataset order,
# which the later label/prediction alignment relies on.
train_dataloader = DataLoader(train_dataset, batch_size=configs['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=configs['batch_size'], shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=configs['batch_size'], shuffle=False)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Last expression of the cell: display the chosen device.
device

In [None]:
# Folder intended for checkpoints.
# NOTE(review): this sits under MyDrive/CMU/Project while data_path points at MyDrive/MIMO_data - confirm it exists.
save_folder = '/content/drive/MyDrive/CMU/Project/Ckpt'

In [None]:
# Build the regressor and its optimizer. The learning rate comes from
# configs (the same 1e-4 value) so it is defined in one place only.
GFPModel = GFPRegressor(device=device)
optimizer = torch.optim.AdamW(GFPModel.parameters(), lr=configs['lr'])

### Test the model

# Collect one prediction per test sample.
# NOTE(review): no visible cell calls train(...) or loads a checkpoint before this line - confirm the model is trained.
test_pred = test(GFPModel, test_dataloader, optimizer, device=device)

In [None]:
# Move every prediction to the CPU, then sort the predictions in ascending order.
preds = np.array([x.cpu() for x in test_pred])
# Keep the sorting permutation so the labels can be reordered the same way.
sorts = np.argsort(preds)
preds = preds[sorts]

In [None]:
# Reorder the true labels with the same permutation used for the predictions.
labels = test_dataset.label[sorts]

In [None]:
# Scatter predictions against true labels; the red line is the identity y = x.
plt.scatter(labels, preds)
plt.plot(labels, labels, color='r', alpha=0.8)
plt.xlabel('True labels')
plt.ylabel('Predictions')
# NOTE(review): the title is fixed text - confirm it matches the run being plotted.
plt.title('Model2 after 10 epochs with batch norm')

## 3- MIMO Model

### Inspect the Data

In [None]:
# Reload the three splits for the MIMO experiments.
train_df, val_df, test_df = get_gfp_dfs(data_path)

In [None]:
# Print the (rows, columns) shape of each split.
print(train_df.shape)
print(val_df.shape)
print(test_df.shape)

In [None]:
# Training rows with exactly one mutation, re-indexed from 0.
train_df_1 = train_df[train_df.num_mutations==1].reset_index(drop=True)
train_df_1.shape

In [None]:
# Training rows with exactly two mutations, re-indexed from 0 like the
# other per-mutation subsets.
train_df_2 = train_df[train_df.num_mutations==2].reset_index(drop=True)
train_df_2.shape

In [None]:
# Training rows with exactly three mutations, re-indexed from 0.
train_df_3 = train_df[train_df.num_mutations==3].reset_index(drop=True)
train_df_3.shape

In [None]:
# Validation rows with exactly one mutation, re-indexed from 0.
val_df_1 = val_df[val_df.num_mutations==1].reset_index(drop=True)
val_df_1.shape

In [None]:
# Validation rows with exactly two mutations, re-indexed from 0.
val_df_2 = val_df[val_df.num_mutations==2].reset_index(drop=True)
val_df_2.shape

In [None]:
# Validation rows with exactly three mutations, re-indexed from 0.
val_df_3 = val_df[val_df.num_mutations==3].reset_index(drop=True)
val_df_3.shape

In [None]:
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Get Tokenizer and Encoder from ESM

In [None]:
# Shared ESM-2 tokenizer and encoder; the encoder is moved to the selected
# device and is only ever called under torch.no_grad() by MultiInputDataset.
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_encoder = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device)

### Define Multi-Input Dataset Dataloaders

In [None]:
import torch
from torch.utils.data import Dataset,DataLoader

class MultiInputDataset(Dataset):
    """Dataset that groups three sequences (one per frame) into one MIMO sample.

    In training mode (``test=False``) item ``idx`` is built from row ``idx``
    of ``df_1``, ``df_2`` and ``df_3``: a ``(3, hidden)`` tensor of pooled
    encoder embeddings and a ``(3,)`` tensor with the three targets.

    In test mode (``test=True``) only ``df_1`` is used: the single embedding
    is repeated three times and the target is a scalar tensor.

    Args:
        df_1: dataframe with ``primary`` and ``log_fluorescence`` columns.
        tokenizer: callable turning sequences into encoder inputs.
        encoder: callable returning an object with ``last_hidden_state``.
        device: device the tokenized inputs are moved to.
        df_2, df_3: further dataframes, required unless ``test`` is True.
        test (bool): select the single-input test mode.

    Raises:
        ValueError: in training mode, if ``df_2``/``df_3`` are missing or
            the three frames differ in length.
    """

    def __init__(self, df_1, tokenizer,encoder,device,df_2=None, df_3=None,test=False):
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.device = device
        self.test = test

        # reset_index keeps integer look-ups valid even when the caller
        # passes a filtered or sampled frame with a non-contiguous index.
        self.input_1 = df_1.primary.reset_index(drop=True)
        self.output_1 = df_1.log_fluorescence.reset_index(drop=True)
        if not test:
            if df_2 is None or df_3 is None:
                raise ValueError("df_2 and df_3 are required when test is False")
            if not (len(df_1) == len(df_2) == len(df_3)):
                raise ValueError("df_1, df_2 and df_3 must have the same length")
            self.input_2 = df_2.primary.reset_index(drop=True)
            self.input_3 = df_3.primary.reset_index(drop=True)
            self.output_2 = df_2.log_fluorescence.reset_index(drop=True)
            # The third target belongs to df_3 (it was read from df_2 before).
            self.output_3 = df_3.log_fluorescence.reset_index(drop=True)

    def __len__(self):
        return len(self.input_1)

    def _embed(self, sequences):
        """Encode ``sequences`` and mean-pool each one over its real tokens."""
        inputs = self.tokenizer(sequences, truncation=True, padding=True, return_tensors="pt")
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            hidden = self.encoder(**inputs).last_hidden_state
        # Exclude padding from the average; identical to a plain mean when
        # all sequences in the call have the same token length.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def __getitem__(self, idx):
        if self.test:
            # Feed the same sequence to all three ensemble inputs.
            pooled_embedding = self._embed(self.input_1[idx]).repeat(3, 1)
            cat_output = torch.tensor(self.output_1[idx], dtype=torch.float32)
            return pooled_embedding, cat_output

        pooled_embedding = self._embed(
            [self.input_1[idx], self.input_2[idx], self.input_3[idx]]
        )
        cat_output = torch.tensor(
            [self.output_1[idx], self.output_2[idx], self.output_3[idx]],
            dtype=torch.float32,
        )
        return pooled_embedding, cat_output

### MIMO Model Architecture

In [None]:
class MultiInputMultiOutputModel(nn.Module):
    """Three-member MIMO regressor with a shared trunk.

    Each of the three inputs has its own input layer and its own output
    layer; the hidden layers in between are shared by all members.

    Args:
        architecture (list[int]): hidden widths; ``architecture[0]`` is the
            width produced by the input layers.
        data_dim (int): size of each input embedding.
        ens_size (int): kept for compatibility; the model always has three
            members.
        activation (str): ``'relu'`` selects ReLU, anything else Sigmoid.

    NOTE(review): ``forward`` previously sent all three inputs through
    ``input_1``, so ``input_2``/``input_3`` in checkpoints saved before this
    fix were never trained - retrain before relying on them.
    """

    def __init__(self,architecture, data_dim, ens_size=3, activation='relu'):
        super(MultiInputMultiOutputModel, self).__init__()

        num_logits = 1  # one regression value per member

        # One input layer per ensemble member
        self.input_1 = nn.Linear(data_dim, architecture[0])
        self.input_2 = nn.Linear(data_dim, architecture[0])
        self.input_3 = nn.Linear(data_dim, architecture[0])

        # Hidden layers shared by all members
        self.hidden_layers = nn.ModuleList([
            nn.Linear(architecture[i], architecture[i+1])
            for i in range(len(architecture)-1)
        ])

        # One output layer per ensemble member
        self.output_1 = nn.Linear(architecture[-1], num_logits)
        self.output_2 = nn.Linear(architecture[-1], num_logits)
        self.output_3 = nn.Linear(architecture[-1], num_logits)

        self.activation = nn.ReLU() if activation == 'relu' else nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, 3, data_dim)`` tensor to ``(batch, 3)`` predictions."""
        batch_size = x.shape[0]

        # Each member goes through its own input layer.
        x_1 = self.input_1(x[:, 0, :])
        x_2 = self.input_2(x[:, 1, :])
        x_3 = self.input_3(x[:, 2, :])

        # Stack the members along the batch axis so the shared trunk
        # processes all of them in one pass.
        x = torch.cat((x_1, x_2, x_3), dim=0)
        for layer in self.hidden_layers:
            x = layer(x)
            x = self.activation(x)

        # Split the stacked rows back into members for the output layers.
        output_1 = self.output_1(x[:batch_size,:])
        output_2 = self.output_2(x[batch_size:batch_size*2,:])
        output_3 = self.output_3(x[batch_size*2:,:])

        return torch.cat((output_1, output_2, output_3), dim=1)

In [None]:
# Input layers of width 8 on 320-dimensional embeddings, one shared 8 -> 3 hidden layer, three scalar heads.
model = MultiInputMultiOutputModel(architecture=[8,3], data_dim=320, ens_size=3, activation='relu').to(device)

In [None]:
# Display the module structure.
model

In [None]:
# Load previously saved weights. map_location lets a checkpoint written on a
# GPU be restored on whatever device is in use now.
model.load_state_dict(torch.load('best_model_0.pth', map_location=device))

In [None]:
from torch.utils.tensorboard import SummaryWriter

### Train and Validate model

In [None]:
num_epochs = 10
learning_rate = 0.0001

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=learning_rate)
# Divide the learning rate by 10 every 3 epochs.
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Optional: TensorBoard for visualization
# writer = SummaryWriter()  # Uncomment this line to enable TensorBoard
training_losses = []
valid_losses = []
best_valid_loss = np.inf

# MultiInputDataset requires its three frames to have the same length, so every
# chunk drawn from the 2- and 3-mutation frames must be as long as the
# 1-mutation frame. Derive that size instead of hard-coding it.
n_train_1 = len(train_df_1)
n_val_1 = len(val_df_1)

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    total_loss = 0.0
    n_train_steps = 0

    # Number of passes over train_df_1 per epoch, each paired with a fresh
    # random chunk of 2- and 3-mutation rows (seeded by the epoch number).
    n_batch = 10
    random_train_2 = train_df_2.sample(n=n_train_1*n_batch,random_state=epoch)
    random_train_3 = train_df_3.sample(n=n_train_1*n_batch,random_state=epoch)

    for b in range(n_batch):
        random_batch_2 = random_train_2[n_train_1*b:n_train_1*(b+1)].reset_index(drop=True)
        random_batch_3 = random_train_3[n_train_1*b:n_train_1*(b+1)].reset_index(drop=True)
        train_dataset = MultiInputDataset(df_1 = train_df_1, df_2 = random_batch_2, df_3 = random_batch_3, tokenizer = esm_tokenizer, encoder = esm_encoder, device = device)
        train_loader = DataLoader(train_dataset, batch_size=32)

        for inputs, targets in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_train_steps += 1

    # Mean of the per-step losses over the whole epoch
    average_train_loss = total_loss / n_train_steps
    training_losses.append(average_train_loss)

    # Optional: Log training loss to TensorBoard
    # writer.add_scalar("Training Loss", average_train_loss, epoch)

    # ---- Validation ----
    model.eval()
    with torch.no_grad():
        total_valid_loss = 0.0
        n_valid_steps = 0

        n_batch = 3
        random_val_2 = val_df_2.sample(n=n_val_1*n_batch,random_state=epoch)
        random_val_3 = val_df_3.sample(n=n_val_1*n_batch,random_state=epoch)

        for b in range(n_batch):
            random_batch_2 = random_val_2[n_val_1*b:n_val_1*(b+1)].reset_index(drop=True)
            random_batch_3 = random_val_3[n_val_1*b:n_val_1*(b+1)].reset_index(drop=True)
            val_dataset = MultiInputDataset(df_1 = val_df_1, df_2 = random_batch_2, df_3 = random_batch_3, tokenizer = esm_tokenizer, encoder = esm_encoder, device = device)
            val_loader = DataLoader(val_dataset, batch_size=32)

            for inputs, targets in val_loader:
                outputs = model(inputs)
                valid_loss = criterion(outputs, targets.to(device))

                total_valid_loss += valid_loss.item()
                n_valid_steps += 1

        # Mean of the per-step validation losses
        average_valid_loss = total_valid_loss / n_valid_steps
        valid_losses.append(average_valid_loss)

        # Keep the weights of the best epoch so far.
        if average_valid_loss < best_valid_loss:
            best_valid_loss = average_valid_loss
            torch.save(model.state_dict(), 'best_model.pth')  # Save model

        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {average_train_loss:.4f}, Valid Loss: {average_valid_loss:.4f}")
        scheduler.step()

In [None]:
import matplotlib.pyplot as plt

# Epoch numbers start at 1 and cover every recorded training loss.
epochs = list(range(1, len(training_losses) + 1))

# plot training and validation losses
plt.figure(figsize=(10, 6))
plt.plot(epochs, training_losses, label='Training Loss', marker='o')
plt.plot(epochs, valid_losses, label='Validation Loss', marker='s')
plt.title('Training and Validation Losses over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()
plt.show()

### Evaluate Model - test data

In [None]:
# Test mode: each test sequence is embedded once and repeated for the three
# ensemble inputs, with a single scalar target per sample.
test_dataset = MultiInputDataset(df_1 = test_df, tokenizer = esm_tokenizer, encoder = esm_encoder, device = device, test=True)

test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:
# Rebuild the architecture and restore the best checkpoint written during training.
new_model = MultiInputMultiOutputModel(architecture=[8,3], data_dim=320, ens_size=3, activation='relu').to(device)
new_model.load_state_dict(torch.load('best_model.pth', map_location=device))
new_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    total_test_loss = 0.0

    for inputs, targets in test_loader:
        # Evaluate the restored checkpoint (new_model), not the in-memory
        # `model`, which holds the weights of the last training epoch.
        outputs = new_model(inputs)
        # Ensemble prediction: average the three member outputs.
        outputs = torch.mean(outputs, dim=1)
        test_loss = criterion(outputs, targets.to(device))

        # Accumulate total test loss
        total_test_loss += test_loss.item()

    # Mean of the per-batch test losses
    average_test_loss = total_test_loss / len(test_loader)

    print(f"Test Loss: {average_test_loss:.4f}")

In [None]:
# Continue training for another `num_epochs` epochs. This cell reuses
# num_epochs, learning_rate, training_losses, valid_losses and best_valid_loss
# from the first training cell, so that cell must have been run first.

# Define loss function and optimizer
# NOTE(review): a new optimizer and scheduler restart the learning rate at
# `learning_rate` instead of continuing the earlier decay - confirm this is intended.
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Optional: TensorBoard for visualization
# writer = SummaryWriter()  # Uncomment this line to enable TensorBoard

# MultiInputDataset requires its three frames to have the same length, so every
# chunk drawn from the 2- and 3-mutation frames must be as long as the
# 1-mutation frame. Derive that size instead of hard-coding it.
n_train_1 = len(train_df_1)
n_val_1 = len(val_df_1)

for epoch in range(num_epochs,num_epochs*2):
    # ---- Training ----
    model.train()
    total_loss = 0.0
    n_train_steps = 0

    # Number of passes over train_df_1 per epoch, each paired with a fresh
    # random chunk of 2- and 3-mutation rows (seeded by the epoch number).
    n_batch = 10
    random_train_2 = train_df_2.sample(n=n_train_1*n_batch,random_state=epoch)
    random_train_3 = train_df_3.sample(n=n_train_1*n_batch,random_state=epoch)

    for b in range(n_batch):
        random_batch_2 = random_train_2[n_train_1*b:n_train_1*(b+1)].reset_index(drop=True)
        random_batch_3 = random_train_3[n_train_1*b:n_train_1*(b+1)].reset_index(drop=True)
        train_dataset = MultiInputDataset(df_1 = train_df_1, df_2 = random_batch_2, df_3 = random_batch_3, tokenizer = esm_tokenizer, encoder = esm_encoder, device = device)
        train_loader = DataLoader(train_dataset, batch_size=32)

        for inputs, targets in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_train_steps += 1

    # Mean of the per-step losses over the whole epoch
    average_train_loss = total_loss / n_train_steps
    training_losses.append(average_train_loss)

    # Optional: Log training loss to TensorBoard
    # writer.add_scalar("Training Loss", average_train_loss, epoch)

    # ---- Validation ----
    model.eval()
    with torch.no_grad():
        total_valid_loss = 0.0
        n_valid_steps = 0

        n_batch = 3
        random_val_2 = val_df_2.sample(n=n_val_1*n_batch,random_state=epoch)
        random_val_3 = val_df_3.sample(n=n_val_1*n_batch,random_state=epoch)

        for b in range(n_batch):
            random_batch_2 = random_val_2[n_val_1*b:n_val_1*(b+1)].reset_index(drop=True)
            random_batch_3 = random_val_3[n_val_1*b:n_val_1*(b+1)].reset_index(drop=True)
            val_dataset = MultiInputDataset(df_1 = val_df_1, df_2 = random_batch_2, df_3 = random_batch_3, tokenizer = esm_tokenizer, encoder = esm_encoder, device = device)
            val_loader = DataLoader(val_dataset, batch_size=32)

            for inputs, targets in val_loader:
                outputs = model(inputs)
                valid_loss = criterion(outputs, targets.to(device))

                total_valid_loss += valid_loss.item()
                n_valid_steps += 1

        # Mean of the per-step validation losses
        average_valid_loss = total_valid_loss / n_valid_steps
        valid_losses.append(average_valid_loss)

        # Keep the weights of the best epoch so far (across both training cells).
        if average_valid_loss < best_valid_loss:
            best_valid_loss = average_valid_loss
            torch.save(model.state_dict(), 'best_model.pth')  # Save model

        print(f"Epoch [{epoch + 1}/{num_epochs*2}], Train Loss: {average_train_loss:.4f}, Valid Loss: {average_valid_loss:.4f}")
        scheduler.step()

In [None]:
# Rebuild the architecture and restore the best checkpoint. map_location lets
# a checkpoint written on a GPU be restored on the device in use now.
new_model = MultiInputMultiOutputModel(architecture=[8,3], data_dim=320, ens_size=3, activation='relu').to(device)
new_model.load_state_dict(torch.load('best_model.pth', map_location=device))
new_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    total_test_loss = 0.0

    for inputs, targets in test_loader:
        outputs = new_model(inputs)
        # Ensemble prediction: average the three member outputs.
        outputs = torch.mean(outputs, dim=1)
        test_loss = criterion(outputs, targets.to(device))

        # Accumulate total test loss
        total_test_loss += test_loss.item()

    # Mean of the per-batch test losses
    average_test_loss = total_test_loss / len(test_loader)

    print(f"Test Loss: {average_test_loss:.4f}")