# 第10章　実践編5：タンパク質の「言語」の法則を解き明かす

- 清水 秀幸

##### 入力10-1


In [None]:
%matplotlib inline

import os
import re
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

##### 入力10-2


In [None]:
!pip install transformers sentencepiece

##### 入力10-3

In [None]:
!wget https://services.healthtech.dtu.dk/services/DeepLoc-1.0/deeploc_data.fasta -P ./data -q

##### 入力10-4


In [None]:
!ls ./data

##### 入力10-5


In [None]:
!head -n 6 ./data/deeploc_data.fasta

##### 入力10-6


In [None]:
!wc -l ./data/deeploc_data.fasta

##### 入力10-7


In [None]:
!pip install Bio -q
import Bio

##### 入力10-8


In [None]:
def read_fasta(file_path, columns) :
    """Read a DeepLoc-style FASTA file into a DataFrame.

    Each header is split on whitespace:
    - token 0 is the protein id;
    - token 1 is ``<location>-<membrane>``;
    - any further token marks the record as not part of the training set
      (presumably DeepLoc's ``test`` tag; confirm against the data file).

    Parameters
    ----------
    file_path : str or path-like
        FASTA file to read.
    columns : list of str
        Six column names, used in this order for: id, sequence, sequence
        length, location, membrane label, is_train flag (1 = train, 0 = not).

    Returns
    -------
    pandas.DataFrame
        One row per FASTA record.

    Raises
    ------
    ValueError
        If a header lacks the ``<location>-<membrane>`` token.
    """
    # Biopython is only needed once the function is actually called.
    from Bio.SeqIO.FastaIO import SimpleFastaParser

    records = []
    with open(file_path) as fasta_file:
        for title, sequence in SimpleFastaParser(fasta_file):
            title_splits = title.split()
            if len(title_splits) < 2 or '-' not in title_splits[1]:
                raise ValueError(
                    'Unexpected FASTA header (expected "<id> <location>-<membrane>"): {!r}'.format(title)
                )
            # Split off only the last '-' so a hyphen inside the location
            # name cannot shift the membrane label.
            location, membrane = title_splits[1].rsplit('-', 1)
            is_train = 0 if len(title_splits) > 2 else 1
            records.append([title_splits[0], sequence, len(sequence), location, membrane, is_train])
    return pd.DataFrame(records, columns=columns)

##### 入力10-9


In [None]:
data = read_fasta('./data/deeploc_data.fasta', columns=['id', 'sequence', 'sequence_length', 'location', 'membrane', 'is_train'])
data.head()

##### 入力10-10


In [None]:
len(data)

##### 入力10-11


In [None]:
data['sequence_length'].describe()

##### 入力10-12


In [None]:
data = data[data['sequence_length'] < 1000]
len(data)

##### 入力10-13


In [None]:
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
ax = sns.histplot(data['sequence_length'].values)
ax.set_xlim(0, 1000)
plt.title(f'sequence length distribution')
plt.grid(True)

##### 入力10-14


In [None]:
data.isnull().values.any()

##### 入力10-15


In [None]:
unique_classes = data.location.unique()
print('Number of classes: ', len(unique_classes))
print(unique_classes)

##### 入力10-16


In [None]:
categories = data.location.astype('category').cat
data['location'] = categories.codes
class_names = categories.categories
num_classes = len(class_names)
print(class_names)

##### 入力10-17


In [None]:
data['location']

##### 入力10-18


In [None]:
df_train = data[data.is_train == 1]
df_train = df_train.drop(['is_train'], axis = 1)
print(df_train.shape[0])
df_train.head()

##### 入力10-19


In [None]:
df_test = data[data.is_train == 0]
df_test = df_test.drop(['is_train'], axis = 1)
print(df_test.shape[0])
df_test.head()

##### 入力10-20


In [None]:
from transformers import T5EncoderModel, T5Tokenizer

# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using {device}')

##### 入力10-21


In [None]:
def get_T5_model():
    """Load the ProtT5 encoder checkpoint together with its tokenizer.

    Both come from ``Rostlab/prot_t5_xl_half_uniref50-enc``. The encoder is
    moved to the notebook-level ``device`` and switched to evaluation mode.

    Returns
    -------
    tuple
        ``(model, tokenizer)``: a ``T5EncoderModel`` and a ``T5Tokenizer``
        created with ``do_lower_case=False``.
    """
    checkpoint = 'Rostlab/prot_t5_xl_half_uniref50-enc'
    encoder = T5EncoderModel.from_pretrained(checkpoint).to(device).eval()
    return encoder, T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)

##### 入力10-22


In [None]:
model, tokenizer = get_T5_model()

##### 入力10-23


In [None]:
train_sequences = { i: seq for i, seq in enumerate(df_train['sequence']) }
train_sequences

##### 入力10-24


In [None]:
len(train_sequences)

##### 入力10-25


In [None]:
test_sequences = { i: seq.replace(' ', '') for i, seq in enumerate(df_test['sequence']) }
len(test_sequences)

##### 入力10-26


In [None]:
def get_embeddings( model, tokenizer, seqs, max_residues=4000, max_seq_len=1000, max_batch=100, device=None ):
    """Compute one mean-pooled embedding per protein with an encoder model.

    Sequences are embedded longest-first in batches, which keeps padding small.
    The returned mapping is then re-ordered to follow ``seqs``, so iterating
    over its values lines up with the caller's key order.

    Parameters
    ----------
    model : callable
        Called as ``model(input_ids, attention_mask=...)``; the result must
        expose ``last_hidden_state`` indexed as ``[batch, token, ...]``.
    tokenizer : object
        Provides ``batch_encode_plus(...)`` returning ``'input_ids'`` and
        ``'attention_mask'`` lists.
    seqs : dict
        Identifier -> amino-acid string (one character per residue).
    max_residues : int
        Flush the current batch once its residue budget reaches this value.
    max_seq_len : int
        Adding a sequence longer than this flushes the batch immediately.
    max_batch : int
        Maximum number of sequences per batch.
    device : torch.device or str, optional
        Device for the token tensors; defaults to ``model.device``.

    Returns
    -------
    dict
        ``{'protein_embs': {identifier: numpy.ndarray}}`` in ``seqs`` key order.
        Identifiers whose batch raised ``RuntimeError`` are absent, and a
        warning is printed.
    """
    if device is None:
        device = model.device

    embeddings = {}
    # Longest first, so each batch holds sequences of similar length.
    by_length = sorted(seqs.items(), key=lambda kv: len(kv[1]), reverse=True)
    start = time.time()
    batch = []
    for seq_idx, (pdb_id, seq) in enumerate(by_length, 1):
        seq_len = len(seq)
        # The tokenizer is fed residues separated by single spaces.
        batch.append((pdb_id, ' '.join(seq), seq_len))

        # NOTE: the newest sequence is already in `batch`, so adding seq_len
        # counts it twice; this keeps the residue budget conservative.
        n_res_batch = sum(s_len for _, _, s_len in batch) + seq_len
        is_last = seq_idx == len(by_length)
        if len(batch) >= max_batch or n_res_batch >= max_residues or is_last or seq_len > max_seq_len:
            # Do not rebind `seqs` here: the parameter is needed below to
            # restore the caller's ordering.
            pdb_ids, batch_seqs, seq_lens = zip(*batch)
            batch = []

            token_encoding = tokenizer.batch_encode_plus(batch_seqs, add_special_tokens=True, padding='longest')
            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            try:
                with torch.no_grad():
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                # Best effort (e.g. out of memory): skip this batch, but name
                # every protein in it, not just the last one appended.
                print('RuntimeError during embedding for {} (L={})'.format(list(pdb_ids), max(seq_lens)))
                continue

            for identifier, s_len, hidden in zip(pdb_ids, seq_lens, embedding_repr.last_hidden_state):
                # Average over the first s_len positions (the residues); later
                # positions are padding or, presumably, tokenizer-added tokens.
                protein_emb = hidden[:s_len].mean(dim=0)
                embeddings[identifier] = protein_emb.detach().cpu().numpy().squeeze()

    passed_time = time.time() - start
    # Restore the caller's key order; the length sort only served batching.
    results = {'protein_embs': {k: embeddings[k] for k in seqs if k in embeddings}}
    n_embs = len(results['protein_embs'])
    n_missing = len(seqs) - n_embs
    if n_missing:
        print('WARNING: {} sequence(s) could not be embedded'.format(n_missing))
    avg_time = passed_time / n_embs if n_embs else float('nan')
    print('\n##### EMBEDDING COMPLETED ######')
    print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    return results

##### 入力10-27


In [None]:
train_embeddings = get_embeddings(model, tokenizer, train_sequences)


##### 入力10-28


In [None]:
print(train_embeddings['protein_embs'][0])
print(train_embeddings['protein_embs'][0].shape)

##### 入力10-29


In [None]:
test_embeddings = get_embeddings(model, tokenizer, test_sequences)

##### 入力10-30


In [None]:
# 目的変数(局在情報)と入力変数(タンパク質の1,024次元の数値ベクトル)をまとめてデータセットに変換
train_embedding_matrices = torch.zeros(len(df_train), 1024)
for i, v in enumerate(train_embeddings['protein_embs'].values()):
    train_embedding_matrices[i] = torch.from_numpy(v.astype(np.float32))
target = torch.tensor(df_train['location'].values, dtype=torch.int64)

train_dataset = torch.utils.data.TensorDataset(train_embedding_matrices, target)

##### 入力10-31


In [None]:
#バッチサイズ
batch_size = 64

# shuffle はデフォルトで False のため，学習データのみ True に指定
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)

##### 入力10-32


In [None]:
# Neural-network definition

class Simple_Net(nn.Module):
    """Three-layer MLP classifier over fixed-size embedding vectors.

    Parameters
    ----------
    in_features : int
        Length of each input vector (1,024 for the ProtT5 embeddings here).
    hidden : int
        Width of both hidden layers.
    num_classes : int
        Number of output classes.

    ``forward`` returns raw, unnormalized logits of shape ``(batch, num_classes)``.

    No softmax is applied, for two reasons:
    - ``nn.CrossEntropyLoss`` applies log-softmax itself;
    - ``argmax`` over logits equals ``argmax`` over probabilities.

    Use ``F.softmax(logits, dim=1)`` when probabilities are needed.
    """

    def __init__(self, in_features=1024, hidden=256, num_classes=10):
        super(Simple_Net, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, num_classes)

    # Forward pass, with the activation functions written out explicitly.
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate
simple_net = Simple_Net()

##### 入力10-33


In [None]:
# 損失関数・最適化の設定

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simple_net.parameters(), lr=0.001)

##### 入力10-34


In [None]:
# 100エポック学習

loss_history = []

for epoch in range(100):
    total_loss = 0
    for x, y in train_loader:

        # 学習ステップ
        optimizer.zero_grad()
        outputs = simple_net(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    loss_history.append(total_loss)
    print(epoch + 1, total_loss)

##### 入力10-35


In [None]:
plt.plot(loss_history)

##### 入力10-36


In [None]:
# 目的変数(局在情報)と入力変数(タンパク質の1,024次元の数値ベクトル)をまとめてデータセットに変換
test_embedding_matrices = torch.zeros(len(df_test), 1024)

for i, v in enumerate(test_embeddings['protein_embs'].values()):
    test_embedding_matrices[i] = torch.from_numpy(v.astype(np.float32))
target = torch.tensor(df_test['location'].values, dtype=torch.int64)

test_dataset = torch.utils.data.TensorDataset(test_embedding_matrices, target)

##### 入力10-37


In [None]:
# バッチサイズ
batch_size = 64

# shuffle はデフォルトで False
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size)

##### 入力10-38


In [None]:
## テストデータにおける正解率を検証

correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        outputs = simple_net(x)
        _, predicted = torch.max(outputs.data, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
print('正解率', int(correct)/total*100)

##### 入力10-39


In [None]:
true_list = []
pred_list = []

with torch.no_grad():
    for x, y in test_loader:
        outputs = simple_net(x)
        _, predicted = torch.max(outputs.data, 1)
        pred_list += predicted.detach().numpy().tolist()
        true_list += y.detach().numpy().tolist()

##### 入力10-40


In [None]:
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(true_list, pred_list)
print(cm)

##### 入力10-41


In [None]:
sns.heatmap(cm)

##### 入力10-42


In [None]:
cm = pd.DataFrame(data=cm, index=class_names.tolist(),
                                              columns=class_names.tolist())
sns.set(rc = {'figure.figsize':(15,8)})
sns.heatmap(cm, square=True, cbar=True, annot=True, cmap='Blues', fmt='d')
plt.yticks(rotation=0)
plt.xlabel("Prediction", fontsize=13, rotation=0)
plt.ylabel("Ground Truth", fontsize=13)
ax.set_ylim(len(cm), 0)
print('アミノ酸配列のみからタンパク質局在を予測した結果')