In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from tools.datasets import SemanticSimilarityDataset
from tools.utils import TrainingProgress
import matplotlib.pyplot as plt



In [2]:
ss_bp_train = SemanticSimilarityDataset('../83333/train_data/')
ss_bp_val = SemanticSimilarityDataset('../83333/val_data/')
ss_bp_test = SemanticSimilarityDataset('../83333/test_data/')

Output()

Output()

Output()

In [3]:
for l, d in zip(['train', 'validation', 'test'], [ss_bp_train, ss_bp_val, ss_bp_test]):
    print(l, len(d))

train 1855044
validation 207025
test 207025


In [4]:
train_loader = DataLoader(ss_bp_train, batch_size=256, num_workers=16)
val_loader = DataLoader(ss_bp_val, batch_size=175, num_workers=16, shuffle=True)

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [6]:
class SiameseSimilarityNet(nn.Module):
    
    def __init__(self):
        super(SiameseSimilarityNet, self).__init__()
        self.prot2vec = nn.Sequential(
            nn.Linear(7098, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        
    def forward(self, p1, p2):
        p1 = self.prot2vec(p1)
        p2 = self.prot2vec(p2)
        batch_size = p1.shape[0]
        dim = p1.shape[1]
        return torch.bmm(p1.reshape(batch_size,1,dim), p2.reshape(batch_size,dim,1)).flatten()
        

In [7]:
model = SiameseSimilarityNet().to(device)

In [8]:
def count_parameters(model):
    temp = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'The model architecture:\n\n', model)
    print(f'\nThe model has {temp:,} trainable parameters')
    
count_parameters(model)

The model architecture:

 SiameseSimilarityNet(
  (prot2vec): Sequential(
    (0): Linear(in_features=7098, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=256, out_features=256, bias=True)
    (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
)

The model has 7,995,392 trainable parameters


In [9]:
# saving and loading checkpoint mechanisms
def save_checkpoint(save_path, model, optimizer, val_loss):
    if save_path==None:
        return
    save_path = save_path 
    state_dict = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'val_loss': val_loss}

    torch.save(state_dict, save_path)

    print(f'Model saved to ==> {save_path}')

def load_checkpoint(model, optimizer):
    save_path = f'siameseNet-batchnorm50.pt'
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict['model_state_dict'])
    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
    val_loss = state_dict['val_loss']
    print(f'Model loaded from <== {save_path}')
    
    return val_loss

In [None]:
optimizer = optim.Adam(model.parameters(), lr = 0.0006)
num_epochs = 50
criterion = nn.MSELoss()
save_name = f'siameseNet-{num_epochs}_epochs.pt'

best_val_loss = float("Inf")
train_losses = []
val_losses = []
cur_step = 0

with TrainingProgress() as progress:
    epochs = progress.add_task("[green]Epochs", progress_type="epochs", total=num_epochs)
    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        training = progress.add_task(f"[magenta]Training [{epoch}]", total=len(train_loader), progress_type="training")
        validation = progress.add_task(f"[cyan]Validation [{epoch}]", total=len(val_loader), progress_type="validation")
        
        for p1, p2, sim in train_loader:
            # forward
            p1 = p1.to(device)
            p2 = p2.to(device)
            sim = sim.to(device)
            outputs = model(p1, p2)
            loss = criterion(outputs, sim)

            #backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress.advance(training)
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
    
        val_running_loss = 0.0
        with torch.no_grad():
            model.eval()
            for p1, p2, sim in val_loader:
                p1 = p1.to(device)
                p2 = p2.to(device)
                sim = sim.to(device)
                outputs = model(p1, p2)
                loss = criterion(outputs, sim)
                val_running_loss += loss.item()
                progress.advance(validation)
        avg_val_loss = val_running_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        
        print('Epoch [{}/{}],Train Loss: {:.4f}, Valid Loss: {:.8f}'
              .format(epoch+1, num_epochs, avg_train_loss, avg_val_loss))
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_checkpoint(save_name, model, optimizer, best_val_loss)
        
        # progress.tasks[training].visible = False
        # progress.tasks[validation].visible = False
        progress.advance(epochs)
                
print("Finished Training")  

Output()

Epoch [1/50],Train Loss: 5.2281, Valid Loss: 4.60656838
Model saved to ==> siameseNet-50_epochs.pt
Epoch [2/50],Train Loss: 0.0334, Valid Loss: 0.35512739
Model saved to ==> siameseNet-50_epochs.pt
Epoch [3/50],Train Loss: 0.0333, Valid Loss: 0.04086203
Model saved to ==> siameseNet-50_epochs.pt
Epoch [4/50],Train Loss: 0.0222, Valid Loss: 0.16406885
Epoch [5/50],Train Loss: 0.0195, Valid Loss: 0.02697304
Epoch [6/50],Train Loss: 0.0164, Valid Loss: 0.02594845
Model saved to ==> siameseNet-50_epochs.pt
Epoch [7/50],Train Loss: 0.0161, Valid Loss: 0.02544737
Model saved to ==> siameseNet-50_epochs.pt
Epoch [8/50],Train Loss: 0.0161, Valid Loss: 0.02568234
Epoch [9/50],Train Loss: 0.0161, Valid Loss: 0.02585709
Epoch [10/50],Train Loss: 0.0161, Valid Loss: 0.02682058
Epoch [11/50],Train Loss: 0.0161, Valid Loss: 0.02655192
Epoch [12/50],Train Loss: 0.0161, Valid Loss: 0.02506588
Model saved to ==> siameseNet-50_epochs.pt
Epoch [13/50],Train Loss: 0.0160, Valid Loss: 0.02354558
Model save

In [31]:
#plotting of training and validation loss
fix, axs = plt.subplots(nrows=2, figsize=(10,10), facecolor='white')
for i, ax in enumerate(axs):
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(ls=':', zorder=1)
    if i != 0:
        ax.set_yscale('log')
        ax.set_ylabel('Loss (log scale)')
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label="Validation Loss")
    ax.legend(loc='upper right')
plt.savefig('E:/prot2vec/loss-evol.png')
plt.close('all')

In [None]:
x = model.prot2vec(a)

In [None]:
x.cpu().detach().numpy()

In [None]:
embeddings = {'protein':[], 'prot2vec':[]}
for accession, interpro in track(ss_bp_train.interpro_dict.items(), description='Training set'):
    ip = torch.from_numpy(interpro.astype(np.float32)).to(device)
    x = model.prot2vec(ip)
    embeddings['protein'].append(accession)
    embeddings['prot2vec'].append(x.cpu().detach().numpy())


for accession, interpro in track(ss_bp_val.interpro_dict.items(), description='Validation set'):
    ip = torch.from_numpy(interpro.astype(np.float32)).to(device)
    x = model.prot2vec(ip)
    embeddings['protein'].append(accession)
    embeddings['prot2vec'].append(x.cpu().detach().numpy())

df = pd.DataFrame(embeddings)

In [None]:
df

In [None]:
A = torch.from_numpy(a)
B = torch.from_numpy(a)


In [None]:
A.shape

In [None]:
torch.bmm(A.reshape(2,1,3), B.reshape(2,3,1)).flatten()

In [None]:
A.reshape(2,1,3)

In [None]:
torch.tensor(1.0).dtype

In [None]:
len(train_loader)

In [None]:
ss_bp_train.ss_dataset