In [2]:
import torch
import torch.nn as nn

import trainingset_check as ts
import numpy
import random

from pathlib import Path
import numpy as np
import math
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

cpu


In [3]:
train_unfiltered_json = './data/Markus_trainsets/unfiltered.json'
train_unfiltered_tsv = './data/Markus_trainsets/unfiltered_rest_clu.tsv'

train_tolerant_json = './data/Markus_trainsets/tolerant.json'
train_tolerant_tsv = './data/Markus_trainsets/tolerant_rest_clu.tsv'

train_moderate_json = './data/Markus_trainsets/moderate.json'
train_moderate_tsv = './data/Markus_trainsets/moderate_rest_clu.tsv'

train_strict_json = './data/Markus_trainsets/strict.json'
train_strict_tsv = './data/Markus_trainsets/strict_rest_clu.tsv'

validation_fasta = './data/Markus_trainsets/rr_CheZOD117_test_set.fasta'
test_fasta = './data/Markus_trainsets/TriZOD_test_set.fasta'

h5_file = './data/BMRB_unfiltered_all.h5'

output = "./data/Markus_trainsets/Markus_trainsets_plots/"

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(1024,2,kdim=1024, vdim=1024, batch_first=True)
        
        # Apply the linear layer to the H/N-tensor to expand to 1024
        self.linear_layer = nn.Linear(in_features=2, out_features=1024)
    

    def forward(self, input_Q, input_K, input_V, mask):
        # Apply the linear layer to the H/N-tensor to expand to 1024
        input_V = self.linear_layer(input_V)
 
        self.attn_output, self.attn_output_weights = self.multihead_attn(input_V, input_K, input_Q, need_weights=True)#, attn_mask=mask)
       
        return self.attn_output_weights 
              


In [5]:
from torch.nn.utils.rnn import pad_sequence

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

class MyCollator(object):
    
    def __call__(self, batch):
        # batch is a list of the samples returned by your __get_item__ method in your CustomDataset
        ids, X, Y, Z, Mask = zip(*batch)
        X = pad_sequence(X, batch_first=True)
        Y = pad_sequence(Y, batch_first=True)
        Z = pad_sequence(Z, batch_first=True)
        Mask = pad_sequence(Mask, batch_first=True)
        return (list(ids), X, Y, Z, Mask)


class CustomDataset(torch.utils.data.Dataset):
    
    def __init__(self, samples, first):
       
        self.item = []
        self.seq = []
        self.first = first
        
        #self.linear_layer = nn.Linear(in_features=2, out_features=1024)
        
        for seq, item in samples.items():
            self.seq.append(seq)
            self.item.append(item)
            
            #i = item[0]
            #y = self.linear_layer(torch.tensor(i[2], dtype=torch.float32))
            #l = [(i[0],torch.tensor(i[1], dtype=torch.float32),y,torch.tensor(i[3]), torch.tensor(i[4]))]
            #self.item.append(l)
            
            
        self.data_len = len(self.item)    
        

        
        
    def __len__(self):
        return self.data_len
    
    def __getitem__(self, index):
        
        curr_item = self.item[index]
        
        if self.first:
            i = 0
        else:
            length = len(curr_item)
            i = randint(0, length-1)
            
        item = curr_item[i]
        ids = item[0]
        x = item[1]
        y = item[2]
        z = item[3]
        mask = item[4]
        
        
        
        return (ids, torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(z), torch.tensor(mask))

    
def get_dataloader(customdata, batch_size, first):
    # Create dataloaders with collate function
    my_collator = MyCollator()
    dataset = CustomDataset(customdata, first)
    return torch.utils.data.DataLoader(dataset=dataset,
                                        batch_size=batch_size, 
                                        shuffle=True,
                                        drop_last=False,
                                        collate_fn=my_collator)
    
    
class EarlyStopper():
    def __init__(self, log_dir):
        self.log_dir      = log_dir
        self.checkpoint_p = log_dir / 'checkpoint.pt'
        self.epsilon      = 1e-3 # the minimal difference in improvement a model needs to reach
        self.min_loss     = np.Inf # counter for lowest/best overall-loss
        self.n_worse      = 0 # counter of consecutive non-improving losses
        self.patience     = 500 # number of max. epochs accepted for non-improving loss
        
       
    
    def load_checkpoint(self):
        state = torch.load( self.checkpoint_p)
        model = get_model()
        model.load_state_dict(state['state_dict'])
        print('Loaded model from epoch: {:.1f}'.format(state['epoch']))
        return model, state['epoch']
    
    def save_checkpoint(self, model, epoch, optimizer):
        state = { 
                    'epoch'      : epoch,
                    'state_dict' : model.state_dict(),
                    'optimizer'  : optimizer.state_dict(),
                }
        torch.save( state, self.checkpoint_p )
        return None
    

         
        
    def check_performance(self, model, val_loader, crit, optimizer, epoch, num_epochs):
        current_loss, acc, acc_rndm = testing(model, val_loader, crit, epoch, num_epochs, set_name="VALID")
    
        # if the model improved compared to previously best checkpoint
        if current_loss < (self.min_loss - self.epsilon):
            print('New best model found with loss= {:.3f}'.format(current_loss))  
            self.save_checkpoint( model, epoch, optimizer)
            self.min_loss = current_loss # save new loss as best checkpoint
            self.n_worse  = 0
        else: # if the model did not improve further increase counter
            self.n_worse += 1
            if self.n_worse > self.patience: # if the model did not improve for 'patience' epochs
                print('Stopping due to early stopping after epoch {}!'.format(epoch))
                return True, current_loss, acc, acc_rndm
        return False, current_loss, acc, acc_rndm

    

        
def training(model, trainloader, crit, optimizer, epoch, num_epochs):
    
    model.train() # ensure model is in training mode (dropout, batch norm, ...)
    accuracies = []
    batch = 0
    batch_loss = 0

    for i, (_, X, Y, Z, Mask) in enumerate(trainloader): # iterate over all mini-batches in train
      
        optimizer.zero_grad() # zeroes the gradient buffers of all parameters
        
        X    = X.to(device)
        Y    = Y.to(device)
        Z    = Z.to(device)
        Mask = Mask.to(device)
        
        
        #train
        Yhat = model(X,X,Y, Mask)
        
        #Loss
        loss = get_loss(Z, Yhat, Mask, crit)
        
        
        #torch.autograd.set_detect_anomaly(True)
        
        #backpropagation
        loss.backward() 
        optimizer.step()
        
        
        for name, param in model.named_parameters():
            print(name)
            print(param)

        #sum up loss per batch
        batch += 1
        batch_loss += loss.detach().cpu().numpy()
       
        
        
    #collect loss for plot: mean over batches per epoch
    epoch_loss = batch_loss/batch
    
        
    #end = time.time()
    if epoch % 1 == 0 or epoch == num_epochs:
        out = ('Epoch [{}/{}], TRAIN loss: {:.2f}').format( 
                    epoch, num_epochs, 
                    epoch_loss,
                    #end-start,
                    )
        print(out)
        
    return epoch_loss



def testing( model, testloader, crit, epoch, num_epochs, log_dir=None, set_name=None):
    model.eval() # [ensure model is in training mode (dropout, batch norm, ...)]
    accuracies = []
    accuracies_rndm = []
    

    #start = time.time()
    results = {}
    
    batch = 0
    batch_loss = 0
    
    
    for i, (pdb_ids, X, Y, Z, Mask) in enumerate(testloader):
        
        # IN: [B, L, F] OUT: [B, N, L]
        X    = X.to(device)
        Y    = Y.to(device)
        Z    = Z.to(device)
        Mask = Mask.to(device)
        
        
        
        with torch.no_grad():
            Yhat = model(X, X, Y, Mask)
        
        
        
        loss = get_loss(Z, Yhat, Mask, crit)
        
        
        # sum up loss per batch
        batch += 1
        batch_loss += loss.detach().cpu().numpy()
        
        
        #get max probability of class
        max_out = torch.max(Yhat, 2)
        Yhat_ = max_out.indices
        
        
        # IN: [B, N, L] OUT: [B, L]
        # iterate over every sample in batch
        for sample_idx in range(0,Yhat.shape[0]):
            
            yhat = Yhat_[sample_idx] # get single sample/protein from mini-batch
            y = Z[sample_idx]
            
            print(y)
            print(yhat)
            
            l_y = len(yhat)
            #create index vector as groundtruth
            

            #acc = correct predictions/number of predictions
            comp = (y == yhat)
            true = (comp == True).sum()
            acc = true/l_y
            accuracies.append(acc.item())

            #random index list to compare accuracy against
            res = torch.randperm(l_y).unsqueeze(0)
            comp_rndm = (res == yhat)
            true_rndm = (comp_rndm == True).sum()
            acc_rndm = true_rndm/l_y
            accuracies_rndm.append(acc_rndm.item())
            
             
            if epoch==num_epochs: # store predictions of final checkpoint for writing to log
                pdb_id = pdb_ids[sample_idx]
                results[pdb_id] = (','.join([str(i.detach().cpu().numpy()) for i in y]),
                                   ','.join([str(j.detach().cpu().numpy()) for j in yhat])
                                   ,accuracies[-1])
    
    
    
    
   
    #collect loss for plot: mean over batches per epoch
    epoch_loss = batch_loss/batch
            
    #end = time.time()
    if epoch % 1 == 0 or epoch == num_epochs:
        out = ('Epoch [{}/{}], {} loss: {:.2f}, Accuracy Mean: {:.2f}, Accuracy RNDM: {:.2f}').format( 
                    epoch,num_epochs, set_name,
                    epoch_loss,
                    sum(accuracies)/len(accuracies), sum(accuracies_rndm)/len(accuracies_rndm)
                    #end-start,
                    )
        print(out)
        
    if epoch==num_epochs:
        write_predictions(results, log_dir, set_name)
    
    

  
    return epoch_loss, accuracies, accuracies_rndm





def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def write_predictions(results, log_dir, set_name):
    out_p = log_dir / (set_name + '_log.txt')
    with open(out_p, 'w+') as out_f:
        out_f.write('\n'.join( 
            ">{},y,yhat,acc={:.3f}\n{}\n{}".format(pdb_id, acc, y, yhat) 
                              for pdb_id, (y,yhat,acc) in results.items() ) )
    return None




        
def get_model():
    return MultiHeadAttention().to(device)

        

#Mask residues with -100    
#def mask(true, predict, mask_index):
#    return true[mask_index],predict[mask_index] 


#calcluate Loss 
def get_loss(true, predict, mask_index ,crit):
    #y_masked, yhat_masked = mask(true, predict, mask_index)
    #loss = crit(yhat_masked.float(),y_masked.float())
    
    
    
    p = torch.permute(predict, (0, 2, 1))
    #print(p.shape)
    #test = torch.flatten(predict,start_dim=1, end_dim=2)
    
    #print(true.shape)
    
    loss = crit(p,true)
    
    return loss   


#Get Loss of Epoch 0        
def get_initialized_loss(model, dataloader, crit):
    batch = 0
    batch_loss = 0
   
    
    for i, (pdb_ids, X, Y, Z, Mask) in enumerate(dataloader):
        
        X    = X.to(device)
        Y    = Y.to(device)
        Z    = Z.to(device)
        Mask = Mask.to(device)
        
        
        
        with torch.no_grad():
            Yhat = model(X,X,Y, Mask)
            
        
        loss = get_loss(Z, Yhat, Mask,crit)
            
        
        batch += 1
        batch_loss += loss.detach().cpu().numpy()
        

    epoch_loss = batch_loss/batch
    
    return epoch_loss


cpu


In [6]:
def predict(batch_size,learning_rate, num_epochs,train,valid,test, output, train_type):

    
    root = Path.cwd() 
    # create log directory if it does not exist yet
    log_root = root / "log"
    if not log_root.is_dir():
        print("Creating new log-directory: {}".format(log_root))
        log_root.mkdir()

    log_dir = log_root

    model = get_model()

    early_stopper = EarlyStopper(log_dir)

    n_free_paras = count_parameters(model)
    print('Number of free parameters: {}'.format(n_free_paras))

    crit = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)


    #Loss
    train_loss = []
    valid_loss = []
    
    
    #Data
    train_loader = get_dataloader(train,batch_size=batch_size,first=True)
    val_loader = get_dataloader(valid,batch_size=batch_size,first=True)
    test_loader = get_dataloader(test,batch_size=batch_size,first=True)

    
    #Initial Loss of Epoch 0
    t_epoch_loss = get_initialized_loss(model, train_loader,crit)
    v_epoch_loss = get_initialized_loss(model, val_loader, crit)

    train_loss.append(math.sqrt(t_epoch_loss))
    valid_loss.append(math.sqrt(v_epoch_loss))
    

    
    print(train_type)

    #start = time.time()
    for epoch in range(num_epochs): # for each epoch: train & valid
        stop, v_epoch_loss, acc, acc_rndm = early_stopper.check_performance(model, val_loader, crit, optimizer, epoch, num_epochs)
        if stop: # if early stopping criterion was reached
            break
        t_epoch_loss = training(model, train_loader, crit, optimizer, epoch, num_epochs) 
        
        # Collect loss per epoch
        train_loss.append(math.sqrt(t_epoch_loss))
        valid_loss.append(math.sqrt(v_epoch_loss))

     


    # load the model weights of the best checkpoint
    model = early_stopper.load_checkpoint()[0]
    #end = time.time()
    #print('Total training time: {}[m]'.format((end-start)/60))
    print('Running final evaluation on the best checkpoint.')
    _, acc, acc_rndm = testing(model,test_loader, crit, epoch, num_epochs,log_dir=log_dir, set_name="Test")

    
    

    #Plots
    #Precision/Coverage
    #plot_precision_coverage(precision, coverage, bins, model_choice, num_epochs,learning_rate,hidden_dim,output,train_type,title="Coverage and Accuracy",include_acc=False)
    
    #Accuracy
    #ts.get_boxplot_accuracy(acc, acc_rndm, 'CrossAttention',num_epochs,learning_rate,0,output,train_type)
    
    
    
    #Pearson correlation: predicted vs groundtruth
    #pearson = get_contour_plot(hn[0],hn[1],hn[2],hn[3], model_choice,epoch,learning_rate,hidden_dim,output,train_type)
    
    #Loss
    t = 'Training and Validation Loss ('+train_type+', '+'CrossAttention'+')'
    out_loss = 'loss_'+train_type+'_'+str(num_epochs)+'_'+str(learning_rate)+'.png'
    
    plot_simple(train_loss, valid_loss, 'Training Loss', 'Validation Loss', 'Epochs', 'Loss', 
            t, 10, output, out_loss)
    
    
    return acc, acc_rndm

In [13]:
predict(128, 1e-3, 100,t,t,t, "./", 'strict')

Number of free parameters: 4201472
strict
tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58,
        59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 77,
        78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91])
tensor([37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37])
Epoch [0/100], VALID loss: 4.52, Accuracy Mean: 0.01, Accuracy RNDM: 0.01
New best model found with loss= 4.522
multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0232, -0.028

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0248, -0.0302,  0.0158,  ..., -0.0263, -0.0339, -0.0149],
        [-0.0144, -0.0091, -0.0277,  ...,  0.0308,  0.0367, -0.0263],
        [-0.0225,  0.0242, -0.0279,  ..., -0.0328, -0.0113, -0.0273],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0026, 0.0026, 0.0026,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0262, -0.0316,  0.0144,  ..., -0.0277, -0.0325, -0.0163],
        [-0.0158, -0.0105, -0.0290,  ...,  0.0295,  0.0381, -0.0277],
        [-0.0239,  0.0228, -0.0292,  ..., -0.0341, -0.0099, -0.0287],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0039, 0.0040, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0270, -0.0324,  0.0137,  ..., -0.0285, -0.0318, -0.0171],
        [-0.0166, -0.0112, -0.0298,  ...,  0.0287,  0.0388, -0.0284],
        [-0.0246,  0.0220, -0.0299,  ..., -0.0349, -0.0092, -0.0294],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0047, 0.0047, 0.0047,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0273, -0.0328,  0.0133,  ..., -0.0289, -0.0314, -0.0174],
        [-0.0170, -0.0116, -0.0302,  ...,  0.0283,  0.0392, -0.0288],
        [-0.0250,  0.0217, -0.0303,  ..., -0.0353, -0.0088, -0.0298],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0051, 0.0051, 0.0050,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0276, -0.0330,  0.0131,  ..., -0.0291, -0.0311, -0.0177],
        [-0.0172, -0.0119, -0.0304,  ...,  0.0281,  0.0395, -0.0291],
        [-0.0253,  0.0214, -0.0305,  ..., -0.0356, -0.0085, -0.0301],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0053, 0.0053, 0.0053,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0278, -0.0332,  0.0129,  ..., -0.0293, -0.0309, -0.0179],
        [-0.0174, -0.0121, -0.0306,  ...,  0.0279,  0.0397, -0.0293],
        [-0.0254,  0.0212, -0.0307,  ..., -0.0357, -0.0083, -0.0303],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0055, 0.0055, 0.0055,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0279, -0.0333,  0.0128,  ..., -0.0294, -0.0308, -0.0180],
        [-0.0175, -0.0122, -0.0307,  ...,  0.0277,  0.0398, -0.0294],
        [-0.0256,  0.0211, -0.0308,  ..., -0.0359, -0.0082, -0.0304],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0056, 0.0056, 0.0056,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0280, -0.0334,  0.0127,  ..., -0.0295, -0.0307, -0.0181],
        [-0.0176, -0.0123, -0.0308,  ...,  0.0277,  0.0399, -0.0295],
        [-0.0257,  0.0210, -0.0309,  ..., -0.0359, -0.0081, -0.0305],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0057, 0.0057, 0.0057,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0281, -0.0335,  0.0126,  ..., -0.0296, -0.0307, -0.0182],
        [-0.0177, -0.0123, -0.0309,  ...,  0.0276,  0.0399, -0.0295],
        [-0.0257,  0.0209, -0.0309,  ..., -0.0360, -0.0081, -0.0305],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0057, 0.0058, 0.0057,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0281, -0.0335,  0.0126,  ..., -0.0296, -0.0306, -0.0182],
        [-0.0177, -0.0124, -0.0309,  ...,  0.0276,  0.0400, -0.0296],
        [-0.0258,  0.0209, -0.0309,  ..., -0.0360, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0058, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0281, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0182],
        [-0.0178, -0.0124, -0.0309,  ...,  0.0275,  0.0400, -0.0296],
        [-0.0258,  0.0209, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0058, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0281, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0182],
        [-0.0178, -0.0124, -0.0310,  ...,  0.0275,  0.0400, -0.0296],
        [-0.0258,  0.0209, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0124, -0.0310,  ...,  0.0275,  0.0400, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0124, -0.0310,  ...,  0.0275,  0.0400, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0124, -0.0310,  ...,  0.0275,  0.0400, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0058, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0306],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0307],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0307],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0307],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

multihead_attn.in_proj_weight
Parameter containing:
tensor([[-0.0282, -0.0336,  0.0125,  ..., -0.0297, -0.0306, -0.0183],
        [-0.0178, -0.0125, -0.0310,  ...,  0.0275,  0.0401, -0.0296],
        [-0.0258,  0.0208, -0.0310,  ..., -0.0361, -0.0080, -0.0307],
        ...,
        [ 0.0270, -0.0269,  0.0101,  ..., -0.0015, -0.0089, -0.0137],
        [-0.0177, -0.0180,  0.0166,  ...,  0.0332,  0.0120, -0.0101],
        [-0.0289, -0.0010, -0.0093,  ...,  0.0259, -0.0292, -0.0200]],
       requires_grad=True)
multihead_attn.in_proj_bias
Parameter containing:
tensor([0.0059, 0.0059, 0.0058,  ..., 0.0000, 0.0000, 0.0000],
       requires_grad=True)
multihead_attn.out_proj.weight
Parameter containing:
tensor([[ 0.0156,  0.0082,  0.0213,  ...,  0.0021,  0.0058,  0.0290],
        [-0.0228, -0.0299,  0.0219,  ...,  0.0193,  0.0007, -0.0261],
        [ 0.0229,  0.0207,  0.0092,  ..., -0.0210,  0.0132, -0.0222],
        ...,
        [ 0.0267,  0.0283, -0.0106,  ...,  0.0228, -0.0031,  0.0272],
 

NameError: name 'plot_simple' is not defined

In [12]:
t = dict(list(train_strict.items())[:1])
print(t)

{'GSEVSDKRTCVSLTTQRLPCSRIKTYTITEGSLRAVIFITKRGLKVCADPQATWVRDCVRSMDRKSNTRNNMIQTKPTGTQQSTNTAVTLTG': [('15110_1_1_1', array([[-0.1313  , -0.2944  ,  0.1814  , ...,  0.1593  , -0.10913 ,
        -0.2292  ],
       [ 0.224   , -0.2089  ,  0.2485  , ...,  0.1925  , -0.02985 ,
         0.0915  ],
       [ 0.1415  , -0.1478  ,  0.1986  , ...,  0.1984  , -0.005714,
         0.3906  ],
       ...,
       [-0.173   ,  0.0955  , -0.02847 , ..., -0.0364  , -0.2134  ,
         0.03763 ],
       [-0.07544 , -0.02408 ,  0.573   , ..., -0.1737  ,  0.4885  ,
         0.309   ],
       [-0.0171  , -0.06726 ,  0.2898  , ..., -0.04547 ,  0.0632  ,
        -0.1355  ]], dtype=float16), [(8.753, 123.48), (8.192, 120.975), (8.265, 119.016), (8.306, 123.127), (8.208, 121.457), (8.265, 120.461), (7.923, 114.841), (8.389, 119.285), (8.694, 120.165), (7.736, 114.038), (8.534, 121.778), (8.851, 117.186), (8.42, 115.387), (8.159, 124.315), (7.875, 121.553), (8.542, 123.512), (8.306, 112.272), (7.631, 118.92), (7.744,

In [9]:
embeddings = ts.read_h5(h5_file)

In [10]:
def get_loaderset(json_dct_list, IDS, cluster_rep, embeddings, set_type, filtering):
    
    print("\n"+set_type)
    
    shift_list = []
    hn_all = []
    hn_no100_all = []
    index_no_100_all = []
    masks = []
    
    counter_shifts = 0
    counter_shifts_less20 = 0
    seq_len = 0
    seq_len_less20 = 0
    
    all_counter = 0
    filter1 = 0
    filter2 = 0
    filter3 = 0
    
    all_residues = 0
    filter1_residues = 0
    filter2_residues = 0
    filter3_residues = 0
    
    
    invalid_shifts = []
    
    
    
    #get data of cluster representatives
    for i ,e in enumerate(IDS):
        if e in cluster_rep:
            info = json_dct_list[i]
            
            seq = info.get('seq')
            prot_len = len(seq)
            
            
            h = info.get('H')
            n = info.get('N')
            hn = list(zip(h,n))
            
            #collect all non-null shifts as tuple_list and index_dictionary 
            no100, _ ,_ = collect_non_null_shifts(hn)
            len_no100 = len(no100)

            #collect % of invalid shifts/protein_length
            invalid = (prot_len-len_no100)/prot_len 
            # Pre-Filter: Entries without H/N shifts
            if invalid == 1.:
                continue
            
            
            #count residues/proteins
            all_counter += 1
            all_residues += prot_len
            
            if filtering:
            
                # 1.Filtering: more than 30% X in sequence are out
                cut_off_30_x = prot_len*0.3
                x = seq.count('X')

                if x > cut_off_30_x:
                    filter1 += 1
                    filter1_residues += prot_len
                    continue

                # 2.Filtering: less than 90% valid shifts are out (max. 10% null-shifts)
                cut_off_90_valid = prot_len*0.9
                if len_no100 < cut_off_90_valid:
                    filter2 += 1
                    filter2_residues += prot_len
                    continue

                #collect % of invalid shifts/protein_length
                invalid_shifts.append(invalid)

                # 3.Filtering: less than 20AAs are out 
                #Todo: try 30
                if prot_len < 20:
                    seq_len_less20 += 1
                    filter3 += 1
                    filter3_residues += prot_len
                    continue
                else:
                    seq_len += 1



                # 4.Filtering: outlier shifts set to null
                #filter out outliers if H>=5. and H<= 12. & N>=95. and N<=140.
                h_outliers_out = [i if i<=12. and i>=5. else -100.0 for i in h]
                n_outliers_out = [i if i<=140. and i>=95. else -100.0 for i in n]
                hn = list(zip(h_outliers_out,n_outliers_out))

                

            #collect all non-null shifts as tuple_list and index_dictionary 
            no100, index, mask = collect_non_null_shifts(hn)
            
            if len(no100) == 0:
                continue
                  
            #collect data
            shift_list.append(info)   
            hn_no100_all.append(no100)
            hn_all.append(hn)
            index_no_100_all.append(index)
            masks.append(mask)
            
            
    f1 = all_counter-filter1
    f2 = f1-filter2
    f3 = f2-filter3
    
    f1_r = all_residues-filter1_residues
    f2_r = f1_r - filter2_residues
    f3_r = f2_r - filter3_residues
    
    print("all proteins: "+str(all_counter))
    print("filter1: "+str(f1))
    print("filter2: "+str(f2))
    print("filter3: "+str(f3))
    
    print("\nall residues: "+str(all_residues))
    print("filter1: "+str(f1_r))
    print("filter2: "+str(f2_r))
    print("filter3: "+str(f3_r))
    
    
    #flatten list
    flat_hn = [item for sublist in hn_no100_all for item in sublist]
    
    #get mean of hn
    scaler = StandardScaler()
    scaler.fit(flat_hn)
    mean = scaler.mean_
    #print(set_type)
    print(mean)
    
    
    dct = {}
    hn_all2 = []
    
    for i,entry in enumerate(shift_list):
        hn = hn_all[i]
        hn_no100 = hn_no100_all[i]
        index_no100 = index_no_100_all[i]
        mask = masks[i]
        
        
        #print(hn_no100)
        #print(index_no100)
        
        hn_no100_scale = scaler.transform(hn_no100).tolist()
        #hn = [[-100.0,-100.0] if -100.0 in tpl else hn_no100_scale[index_no100.get(i)] for i, tpl in enumerate(hn)]
        
        seq = entry.get('seq')
        ID = entry.get('ID')
        emb = embeddings.get(ID)
        if emb is None:
            continue
    
       
        
        

        dct[seq] = [(ID,emb,hn_no100,index_no100, mask)]
        hn_all2.extend(hn_no100)
        
   
    return dct, hn_all2
    
    

def collect_non_null_shifts(tuple_list):
    no100 = []
    #index = {} 
    mask = []
    index = [] 
    for i,tpl in enumerate(tuple_list):
        if -100.0 not in tpl:
            no100.append(tpl)
            index.append(i)
            mask.append(False)
            #index[counter] = i
            #counter+=1
        else:
            mask.append(True)
            #index.append(-1)
            
            
    return no100, index, mask


In [11]:
#Train Set 
#cluster_rep_unfiltered = ts.read_tsv(train_unfiltered_tsv)
#cluster_rep_tolerant = ts.read_tsv(train_tolerant_tsv)
#cluster_rep_moderate = ts.read_tsv(train_moderate_tsv)
cluster_rep_strict = ts.read_tsv(train_strict_tsv)


#get ids of validation and test set from fasta
validation_ids = ts.get_ids(validation_fasta)
test_ids = ts.get_ids(test_fasta)

#redundancy reduction (Test and Validation Set) included
train_unfiltered_list, train_unfiltered_IDS, validation_list, test_list = ts.read_json(train_unfiltered_json, validation_ids, test_ids)
#train_tolerant_list, train_tolerant_IDS, _, _ = ts.read_json(train_tolerant_json, validation_ids, test_ids)
#train_moderate_list, train_moderate_IDS, _, _ = ts.read_json(train_moderate_json, validation_ids, test_ids)
train_strict_list, train_strict_IDS, _, _ = ts.read_json(train_strict_json, validation_ids, test_ids)

#train_unfiltered,hn_unfiltered = get_loaderset(train_unfiltered_list, train_unfiltered_IDS, cluster_rep_unfiltered, embeddings, "Unfiltered", True)
#train_tolerant,hn_tolerant = get_loaderset(train_tolerant_list, train_tolerant_IDS, cluster_rep_tolerant, embeddings, "Tolerant", True)
#train_moderate,hn_moderate = get_loaderset(train_moderate_list, train_moderate_IDS, cluster_rep_moderate, embeddings, "Moderate", True)
train_strict,hn_strict = get_loaderset(train_strict_list, train_strict_IDS, cluster_rep_strict, embeddings, "Strict", True)


valid,hn_valid = get_loaderset(validation_list, validation_ids, validation_ids, embeddings, "Validation", True)
test,hn_test = get_loaderset(test_list, test_ids, test_ids, embeddings, "Test", True)



Strict
all proteins: 990
filter1: 990
filter2: 482
filter3: 480

all residues: 114555
filter1: 114555
filter2: 52187
filter3: 52151
[  8.27367275 119.46497478]

Validation
all proteins: 107
filter1: 107
filter2: 35
filter3: 35

all residues: 12908
filter1: 12908
filter2: 3933
filter3: 3933
[  8.26348998 119.94756672]

Test
all proteins: 339
filter1: 339
filter2: 184
filter3: 184

all residues: 38106
filter1: 38106
filter2: 20498
filter3: 20498
[  8.25451807 119.49366396]


In [12]:
def get_data(data_dict):


    # Initialize two empty lists
    emb = []
    hn = []

    # Iterate through the dictionary items
    for key, array_3d in data_dict.items():
        # Extract the first and second dimensions from the 3D array
        first_dim_values = [item[1] for item in array_3d]
        second_dim_values = [item[2] for item in array_3d]

        # Append the values to the respective lists
        emb.extend(first_dim_values)
        hn.extend(second_dim_values)

        
    hn_padd = []
    emb_padd = []

    longest_list = len(max(hn, key=len))

    #padding
    for i,sublist in enumerate(hn):
        #hn
        diff = longest_list - len(sublist)
        padd = [[-100, -100]] * diff
        a = sublist + padd
        hn_padd.append(a)

        #emb
        d = [[-100] * 1024] * diff
        e = [*emb[i], *d]
        emb_padd.append(e)

        
    emb_1 = torch.tensor(numpy.array(emb_padd)).type(torch.float32).transpose(-3,-2)
    hn_1 = torch.tensor(numpy.array(hn_padd)).type(torch.float32).transpose(-3,-2)
    ind = torch.all(torch.eq(hn_1, -1.0000e+02), dim=2).transpose(-2,-1)
    
    input_Q = emb_1
    input_K = emb_1
    input_V = hn_1
    
    multihead_attn = nn.MultiheadAttention(1024,512,kdim=1024, vdim=2)
    attn_output, attn_output_weights = multihead_attn(input_Q, input_K, input_V,need_weights=True, key_padding_mask=ind)
    
    max_out = torch.max(attn_output_weights, 2)
    max_out_i = max_out.indices

    #acc = correct predictions/number of predictions
    true_vector = torch.arange(longest_list).unsqueeze(0)
    comp = (true_vector == max_out_i)
    true = (comp == True).sum()
    acc = true/(longest_list*len(attn_output_weights))
    

    #random index list to compare accuracy against
    res = torch.randperm(longest_list).unsqueeze(0)
    comp_rndm = (res == max_out_i)
    true_rndm = (comp_rndm == True).sum()
    acc_rndm = true_rndm/(longest_list*len(attn_output_weights))
    
    
    return acc[0], acc_rndm[0]


In [248]:
acc_unfiltered, acc_rndm_unfiltered = get_data(train_unfiltered)
acc_tolerant, acc_rndm_tolerant = get_data(train_tolerant)
acc_moderate, acc_rndm_moderate = get_data(train_moderate)
acc_strict, acc_rndm_strict = get_data(train_strict)




f = open("./out_attention.txt", "w")
f.write("Unfiltered:")
f.write("Attention Accuracy (1 head): "+str(acc_unfiltered)+"\nRandom Accuracy: "+str(acc_rndm))
f.write("\n\nTolerant:")
f.write("Attention Accuracy (1 head): "+str(acc_tolerant)+"\nRandom Accuracy: "+str(acc_rndm_tolerant))
f.write("\n\nModerate:")
f.write("Attention Accuracy (1 head): "+str(acc_moderate)+"\nRandom Accuracy: "+str(acc_rndm_moderate))
f.write("\n\nStrict:")
f.write("Attention Accuracy (1 head): "+str(acc_strict)+"\nRandom Accuracy: "+str(acc_rndm_strict))
f.close()

RuntimeError: The size of tensor a (457) must match the size of tensor b (102) at non-singleton dimension 1

In [449]:
def plot_simple(x1, x2, label1, label2, x_label, y_label, title, x_range, output, out_name):

    plt.figure(figsize=(14, 6))
    
    l = range(0, len(x1))
    plt.plot(l, x1, label=label1)
    plt.plot(l, x2, label=label2)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    x_labels = np.arange(min(l), max(l)+1, x_range)
    plt.xticks(x_labels)
    plt.ylim(ymin=0)
    
    
    plt.legend()

    plt.savefig(output+out_name)
    plt.show()
    plt.close()

In [923]:
input_ = torch.randn(2, 3, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
#output = loss(input, target)

#print(input_.shape)
print(input_)
flatt= torch.flatten(input_, start_dim=0)
print(flatt)




#print(target.shape)
#print(target)


tensor([[1.3237, 0.3321, 1.3199],
        [1.3174, 0.8106, 0.9409]], requires_grad=True)
tensor([1.3237, 0.3321, 1.3199, 1.3174, 0.8106, 0.9409],
       grad_fn=<ReshapeAliasBackward0>)
