In [1]:
import torch
import torch.nn as nn
import json
import torch.nn.functional as F

In [2]:
# Read the preprocessed CB513 JSON into `d`.
# (Presumably a dict keyed by the protein index as a string -- the dataset
# class below looks records up with `protein_data[str(id)]`.)
with open('../preprocess/CB513.json', 'r') as json_file:
    d = json.load(json_file)

In [23]:
# The JSON keys are protein indices stored as strings, so `d[0]` raises
# KeyError: 0. Index with the string form and list the fields of one record.
d["0"].keys()

In [3]:
import numpy as np
import torch
import torch.utils.data as data

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so every `.to(device)` call in this notebook still works without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
class S2S(nn.Module):
    """LSTM encoder-decoder that predicts a PSSM row for every residue.

    The amino-acid sequence is embedded and encoded by an LSTM; the final
    encoder state initialises a decoder LSTM whose input at step ``t`` is the
    residue embedding at ``t`` concatenated with the PSSM row of step ``t - 1``.

    Args:
        num_features: Size of the one-hot residue vocabulary.
        size_embedding: Embedding width. The decoder input is
            ``size_embedding * 2``, so this must equal the PSSM width (21).
        bidirectional: Passed to the encoder LSTM.
        encoder_hidden_size: Hidden size of the encoder LSTM.
        decoder_hidden_size: Hidden size of the decoder LSTM.

    NOTE(review): the encoder state is handed to the decoder unchanged, so
    ``bidirectional=True`` or ``encoder_hidden_size != decoder_hidden_size``
    will produce a state shape the single-layer unidirectional decoder
    rejects -- confirm before using non-default values.
    """

    def __init__(self, num_features=22,
                       size_embedding=21,
                       bidirectional=False,
                       encoder_hidden_size=250,
                       decoder_hidden_size=250):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()

        self.num_features = num_features

        self.encoder = nn.LSTM(input_size=size_embedding,
                               hidden_size=encoder_hidden_size,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=bidirectional)

        # Decoder input = residue embedding concatenated with the previous PSSM row
        self.decoder = nn.LSTM(input_size=size_embedding * 2,
                               hidden_size=decoder_hidden_size,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=False)

        # Embed the one hot vector 22 into 21 -> 21 b/c easier to concatenate w/ the PSSM row this way
        self.embedding = nn.Embedding(num_features, size_embedding)

        self.hidden_to_pssm = nn.Linear(decoder_hidden_size, 21)

    def forward(self, x, pssm):
        """Teacher-forced pass.

        Args:
            x: One-hot residues with the class axis at dim 1, i.e.
                ``(batch, num_features, seq_len)``.
            pssm: Ground-truth PSSM rows, ``(batch, seq_len, 21)``.

        Returns:
            Predicted PSSM rows, ``(batch, seq_len, 21)``.
        """
        # Convert to non one-hot for embedding
        x = x.argmax(dim=1)

        # Embedding layer
        x = self.embedding(x)

        # Only the final (h, c) state is needed to initialise the decoder
        _, (h, c) = self.encoder(x)

        # Teacher forcing: step t sees the ground-truth row of step t - 1.
        # Feeding row t itself would hand the decoder the very value it is
        # asked to predict. The first step gets an all-ones start row (the
        # start value used by the `gen` decoding method later in this notebook).
        start_row = torch.ones_like(pssm[:, :1, :])
        prev_pssm = torch.cat([start_row, pssm[:, :-1, :]], dim=1)
        x = torch.cat([x, prev_pssm], dim=2)

        out, _ = self.decoder(x, (h, c))
        out = self.hidden_to_pssm(out)

        return out
    
class S2SDataset(data.Dataset):
    """In-memory dataset of (one-hot residues, PSSM, length) triples.

    Each record in ``protein_data`` is looked up by ``str(id)`` and must
    provide ``"protein_length"`` and a flat ``"protein_encoding"`` that
    reshapes to 700 rows. Columns 0:22 of each row are kept as the one-hot
    residue encoding and columns 29:50 as the 21-wide PSSM row.

    Args:
        protein_data: Mapping from string id to a protein record.
        ids: Iterable (with ``len``) of ids to load; each is converted with ``str``.
    """

    def __init__(self, protein_data, ids):

        data_len = len(ids)

        # Allocate directly in the final dtypes instead of building float64
        # buffers and copying them afterwards.
        # data_len, 700, 22 one hot
        all_encodings = np.zeros([data_len, 700, 22], dtype=np.uint8)

        # data_len, 700 x 21 PSSM
        all_pssm = np.zeros([data_len, 700, 21], dtype=np.float32)
        all_lengths = []

        for i, protein_id in enumerate(ids):
            if i % 250 == 0:
                print("Loading {0}/{1} proteins".format(i, len(ids)))

            record = protein_data[str(protein_id)]
            all_lengths.append(record["protein_length"])

            # One row per residue position; the column count is inferred.
            reshaped = np.array(record["protein_encoding"]).reshape([700, -1])

            all_encodings[i, :] = reshaped[:, 0:22]
            all_pssm[i, :] = reshaped[:, 29:50]

        self.all_encodings = all_encodings
        self.all_pssm = all_pssm
        self.all_lengths = np.array(all_lengths).astype(np.int32)

        # Sanity check: all three containers hold one entry per protein.
        print(len(all_encodings), len(all_pssm), len(all_lengths))

    def __getitem__(self, index):
        """Return ``(encoding, pssm, length)`` for the protein at ``index``.

        ``encoding`` is a ``(700, 22)`` uint8 array, ``pssm`` a ``(700, 21)``
        float32 array and ``length`` the stored protein length (int32).
        """
        encoding = self.all_encodings[index]
        pssm = self.all_pssm[index]
        length = self.all_lengths[index]

        return encoding, pssm, length

    def __len__(self):
        """Number of proteins loaded."""
        return len(self.all_encodings)


def get_loader(protein_data, ids, batch_size, shuffle, num_workers):
    """Build an ``S2SDataset`` for ``ids`` and wrap it in a DataLoader.

    Returns:
        ``(data_loader, dataset_size)`` -- the loader (default collation) and
        the number of proteins in the underlying dataset.
    """
    dataset = S2SDataset(protein_data, ids)

    loader_options = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    loader = torch.utils.data.DataLoader(**loader_options)

    return loader, len(dataset)

In [4]:

# Number of proteins in the loaded CB513 dictionary.
len_train = len(d)

# Random permutation of the integer indices 0..len_train-1.
# (Not consumed by the loader below, which is given the fixed ids [0, 1, 2].)
ids = np.random.choice(len_train, size=len_train, replace=False)

# Tiny three-protein loader used to smoke-test shapes and the model.
val_loader, len_val = get_loader(d, [0, 1, 2],
                                 batch_size=5,
                                 shuffle=False,
                                 num_workers=1)

Loading 0/3 proteins
3 3 3


In [24]:
Y[0]

tensor([[0.2142, 0.1708, 0.0230,  ..., 0.2709, 0.2689, 0.9234],
        [0.0832, 0.0105, 0.0423,  ..., 0.0042, 0.1192, 0.0092],
        [0.3475, 0.0253, 0.9945,  ..., 0.0111, 0.5000, 0.0251],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

In [5]:
# Grab the first batch only, to check the tensor shapes coming out of the loader.
# (No loop counter: the original one was unused and shadowed the builtin `iter`.)
for X, Y, seq_lens in val_loader:
    print(Y.shape)
    print(X.shape)
    break

torch.Size([3, 700, 21])
torch.Size([3, 700, 22])


In [6]:
# Take the first batch, move the one-hot class axis to dim 1 (batch, 22, seq_len)
# as the model expects, and put both tensors on `device`.
# (No loop counter: the original one was unused and shadowed the builtin `iter`.)
for X, Y, seq_lens in val_loader:
    X = X.permute([0, 2, 1]).long().to(device)
    Y = Y.to(device)
    print(X.shape)
    break

torch.Size([3, 22, 700])


In [7]:
import torch.nn as nn
import torch
import torchvision.models as models
import torch.nn.functional as F
import numpy as np

class S2S(nn.Module):
    """LSTM encoder-decoder that predicts a PSSM row for every residue.

    The amino-acid sequence is embedded and encoded by an LSTM; the final
    encoder state initialises a decoder LSTM whose input at step ``t`` is the
    residue embedding at ``t`` concatenated with the PSSM row of step ``t - 1``
    (ground truth in ``forward``, the model's own prediction in ``gen``).

    Args:
        num_features: Size of the one-hot residue vocabulary.
        size_embedding: Embedding width. The decoder input is
            ``size_embedding * 2``, so this must equal the PSSM width (21).
        bidirectional: Passed to the encoder LSTM.
        encoder_hidden_size: Hidden size of the encoder LSTM.
        decoder_hidden_size: Hidden size of the decoder LSTM.

    NOTE(review): the encoder state is handed to the decoder unchanged, so
    ``bidirectional=True`` or ``encoder_hidden_size != decoder_hidden_size``
    will produce a state shape the single-layer unidirectional decoder
    rejects -- confirm before using non-default values.
    """

    def __init__(self, num_features=22,
                       size_embedding=21,
                       bidirectional=False,
                       encoder_hidden_size=250,
                       decoder_hidden_size=250):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()

        self.num_features = num_features

        self.encoder = nn.LSTM(input_size=size_embedding,
                               hidden_size=encoder_hidden_size,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=bidirectional)

        # Decoder input = residue embedding concatenated with the previous PSSM row
        self.decoder = nn.LSTM(input_size=size_embedding * 2,
                               hidden_size=decoder_hidden_size,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=False)

        # Embed the one hot vector 22 into 21 -> 21 b/c easier to concatenate w/ the PSSM row this way
        self.embedding = nn.Embedding(num_features, size_embedding)

        self.hidden_to_pssm = nn.Linear(decoder_hidden_size, 21)

    def forward(self, x, pssm):
        """Teacher-forced pass used for training.

        Args:
            x: One-hot residues with the class axis at dim 1, i.e.
                ``(batch, num_features, seq_len)``.
            pssm: Ground-truth PSSM rows, ``(batch, seq_len, 21)``.

        Returns:
            Predicted PSSM rows, ``(batch, seq_len, 21)``.
        """
        # Convert to non one-hot for embedding
        x = x.argmax(dim=1)

        # Embedding layer
        x = self.embedding(x)

        # Only the final (h, c) state is needed to initialise the decoder
        _, (h, c) = self.encoder(x)

        # Teacher forcing: step t sees the ground-truth row of step t - 1.
        # Feeding row t itself would hand the decoder the very value it is
        # asked to predict, and would not match what `gen` feeds at inference.
        # The first step gets the same all-ones start row that `gen` uses.
        start_row = torch.ones_like(pssm[:, :1, :])
        prev_pssm = torch.cat([start_row, pssm[:, :-1, :]], dim=1)
        x = torch.cat([x, prev_pssm], dim=2)

        out, _ = self.decoder(x, (h, c))
        out = self.hidden_to_pssm(out)

        return out

    def gen(self, x):
        """Autoregressive decoding: each step is fed the previous prediction.

        Args:
            x: One-hot residues, ``(batch, num_features, seq_len)``.

        Returns:
            Generated PSSM rows, ``(batch, seq_len, 21)``.
        """
        # Convert to non one-hot for embedding
        x = x.argmax(dim=1)

        # Embedding layer
        x = self.embedding(x)

        # Only the final (h, c) state is needed to initialise the decoder
        _, (h, c) = self.encoder(x)

        gen_seq = []
        ht, ct = h, c
        # All-ones start row, created on the same device/dtype as the embeddings
        # (no dependence on a module-level `device`).
        pred_seq = x.new_ones(x.shape[0], 1, self.hidden_to_pssm.out_features)

        for t in range(x.shape[1]):
            xt = x[:, t:t+1, :]
            xt = torch.cat([xt, pred_seq], dim=2)

            # Use this instance's own layers rather than a global model object.
            out, (ht, ct) = self.decoder(xt, (ht, ct))
            pred_seq = self.hidden_to_pssm(out)

            gen_seq.append(pred_seq)

        gen_seq = torch.cat(gen_seq, dim=1)

        return gen_seq

In [8]:
s2s = S2S().to(device)

In [9]:
y = s2s(X, Y)

In [10]:
y = s2s.gen(X)

In [11]:
y.shape

torch.Size([3, 700, 21])

In [12]:
x = X.argmax(axis=1)
x = s2s.embedding(x)
_, (h, c) = s2s.encoder(x)

In [13]:
pred_seq = torch.zeros_like(x[:, 0:1, :]).to(device)

In [14]:
x = X.argmax(axis=1)

# Embedding layer
x = s2s.embedding(x)

# Don't need the singular hidden state
_, (h, c) = s2s.encoder(x)

# Teacher force pssm during training
x = torch.cat([x, Y], axis=2)

out, _ = s2s.decoder(x, (h, c))
out = s2s.hidden_to_pssm(out)

In [15]:
first_seq = s2s.hidden_to_pssm(h).permute([1, 0, 2])

In [16]:
first_seq.shape

torch.Size([3, 1, 21])

In [17]:
seq_holder = torch.zeros_like(Y).cuda()

In [18]:
seq_holder[:, 0:1, :] = first_seq
seq_holder[:, 1:, :] = Y[:, :Y.shape[1]-1, :]

In [19]:
seq_holder.shape

torch.Size([3, 700, 21])

In [20]:
seq_holder.shape

torch.Size([3, 700, 21])

In [21]:
first_seq.shape

torch.Size([3, 1, 21])

In [22]:
# Step-by-step autoregressive decoding, written out for inspection.
# The inputs are rebuilt here rather than reused from earlier cells: `x` had
# been overwritten with the 42-wide teacher-forcing tensor, which made the
# concatenation below 63 wide and the decoder reject it.
x = s2s.embedding(X.argmax(dim=1))
_, (h, c) = s2s.encoder(x)
pred_seq = torch.zeros_like(x[:, 0:1, :])

all_out = []
ht, ct = h, c

for t in range(x.shape[1]):
    xt = x[:, t:t+1, :]
    xt = torch.cat([xt, pred_seq], dim=2)

    out, (ht, ct) = s2s.decoder(xt, (ht, ct))
    pred_seq = s2s.hidden_to_pssm(out)

    all_out.append(pred_seq)

In [None]:
all_out = torch.cat(all_out, dim=1)

In [None]:
all_out.shape

In [None]:
pred_seq.shape

In [None]:
x[:, 0:1, :].shape

In [None]:
def gen(self, x):
    """Autoregressively generate PSSM rows for a batch of sequences.

    Draft of the ``S2S.gen`` method: ``self`` must provide ``embedding``,
    ``encoder``, ``decoder`` and ``hidden_to_pssm``.

    Args:
        x: One-hot residues with the class axis at dim 1, i.e.
            ``(batch, num_features, seq_len)``.

    Returns:
        Tensor ``(batch, seq_len, pssm_width)`` holding the prediction of
        every step, where each step was fed the previous step's prediction.
    """
    # Convert to non one-hot for embedding
    x = x.argmax(dim=1)

    # Embedding layer
    x = self.embedding(x)

    # Only the final (h, c) state is needed to initialise the decoder
    _, (h, c) = self.encoder(x)

    gen_seq = []
    ht, ct = h, c
    # The first step has no previous prediction; start from an all-ones row.
    pred_seq = x.new_ones(x.shape[0], 1, self.hidden_to_pssm.out_features)

    for t in range(x.shape[1]):
        xt = x[:, t:t+1, :]
        xt = torch.cat([xt, pred_seq], dim=2)

        # Use the layers of `self`, not a global model instance.
        out, (ht, ct) = self.decoder(xt, (ht, ct))
        pred_seq = self.hidden_to_pssm(out)

        gen_seq.append(pred_seq)

    gen_seq = torch.cat(gen_seq, dim=1)

    # Return every step's prediction, not just the last decoder output.
    return gen_seq

In [None]:
def train(epochs, model, stats_path,
          train_loader, val_loader,
          optimizer, criterion,
          len_train, len_val,
          latest_model_path,
          best_model_path, optim_path, device, early_stop=10):
    """Train ``model`` with per-sequence masked loss, checkpointing every epoch.

    If ``stats_path`` exists, training resumes from the epoch recorded there;
    otherwise a fresh stats dict is created and a baseline validation pass is
    stored under epoch ``-1``.

    Args:
        epochs: Exclusive upper bound of the epoch range.
        model: Module called as ``model(X, Y)``.
        stats_path: Pickle file holding the running statistics dict.
        train_loader: Iterable of ``(X, Y, seq_lens)`` batches with ``len()``.
        val_loader: Forwarded to ``val``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to each sequence truncated to its length.
        len_train: Unused; kept for interface compatibility.
        len_val: Forwarded to ``val``.
        latest_model_path: Where the model is saved after every epoch.
        best_model_path: Where the model is saved when validation loss improves.
        optim_path: Where the optimizer state dict is saved after every epoch.
        device: Device the batches are moved to.
        early_stop: Number of non-improving epochs tolerated in total. The
            counter is never reset when the loss improves again.

    Returns:
        ``(stats_dict, model)``.

    NOTE(review): relies on ``val`` and ``rec_dd`` being defined elsewhere.
    ``val`` is assumed to return ``(accs, val_loss)``, as the per-epoch call
    below unpacks it -- confirm against its definition.
    """
    # These modules are not imported at the top of the notebook.
    import os
    import pickle as pkl
    import time

    fmt_string = "Epoch[{0}/{1}], Batch[{2}/{3}], Train Loss: {4}"

    # Load stats if path exists
    if os.path.exists(stats_path):
        # NOTE: unpickling runs arbitrary code; only resume from stats files
        # written by this function.
        with open(stats_path, "rb") as f:
            stats_dict = pkl.load(f)
        print(stats_dict["best_epoch"])
        start_epoch = stats_dict["next_epoch"]
        min_val_loss = stats_dict["valid"][stats_dict["best_epoch"]]["loss"]
        print("Stats exist. Loading from {0}. Starting from Epoch {1}".format(stats_path, start_epoch))
    else:
        min_val_loss = np.inf
        stats_dict = rec_dd()
        start_epoch = 0

        # See loss before training. Same call shape as the per-epoch validation
        # below (the old call passed names that do not exist in this function).
        accs, val_loss = val(-1, model, val_loader, len_val, criterion, epochs, device)

        # Update statistics dict
        stats_dict["valid"][-1]["acc"] = accs
        stats_dict["valid"][-1]["loss"] = val_loss

    model.train()
    for epoch in range(start_epoch, epochs):
        train_loss = 0.

        ts = time.time()
        for batch_idx, (X, Y, seq_lens) in enumerate(train_loader):
            optimizer.zero_grad()

            # Model expects the one-hot class axis at dim 1
            X = X.permute([0, 2, 1]).long().to(device)
            Y = Y.to(device)

            outputs = model(X, Y)

            # Sum the loss over sequences, ignoring the zero padding past
            # each sequence's true length.
            loss = 0
            for y, t, seq_len in zip(outputs, Y, seq_lens):
                y_cut = y[:seq_len]
                t_cut = t[:seq_len]

                loss += criterion(y_cut, t_cut)

            batch_loss = loss.item()
            train_loss += batch_loss
            if batch_idx % 10 == 0:
                print(fmt_string.format(epoch, epochs, batch_idx, len(train_loader), batch_loss))

            loss.backward()
            optimizer.step()

        print("\nFinished Epoch {}, Time elapsed: {}, Loss: {}".format(epoch, time.time() - ts,
                                                                       train_loss / len(train_loader)))

        # Avg train loss. Batch losses were un-averaged before when added to train_loss
        stats_dict["train"][epoch]["loss"] = train_loss / len(train_loader)

        # The validation stats after additional epoch
        accs, val_loss = val(epoch, model, val_loader, len_val, criterion, epochs, device)

        # Update statistics dict
        stats_dict["valid"][epoch]["loss"] = val_loss
        stats_dict["next_epoch"] = epoch + 1

        # Save latest model
        torch.save(model, latest_model_path)

        # Save optimizer state dict
        optim_state = {'optimizer': optimizer.state_dict()}
        torch.save(optim_state, optim_path)

        if val_loss <= min_val_loss:
            min_val_loss = val_loss
            # Save best model
            torch.save(model, best_model_path)
            stats_dict["best_epoch"] = epoch
        else:
            early_stop -= 1

        # Save stats
        with open(stats_path, "wb") as f:
            pkl.dump(stats_dict, f)

        if early_stop == 0:
            print('=' * 10, 'Early stopping.', '=' * 10)
            break

        # Set back to train mode
        model.train()

    return stats_dict, model