In [1]:
import torch 
from torch import nn, optim,utils
from torchinfo import summary
import numpy as np

def set_seed(random_seed=41):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)

    
class NPDataset(utils.data.Dataset):

    def __init__(self, X):
        self.X = torch.tensor(X,dtype=torch.float)

    def __getitem__(self, item):
        return self.X[item]

    def __len__(self):
        return len(self.X)

    
class CNN(nn.Module):

    def __init__(self, seq_len=44):
        super(CNN, self).__init__()
        self.seq_len=seq_len
        self.model=self.cnn_model()
        
    def cnn_model(self,):
        model = nn.Sequential()
        model.append(nn.Conv1d(in_channels=4, out_channels=100, kernel_size=(5,), padding=2))
        model.append(nn.ReLU())
        model.append(nn.Conv1d(in_channels=100, out_channels=100, kernel_size=(5,), padding=2))
        model.append(nn.ReLU())
        model.append(nn.Dropout(p=0.3))
        model.append(nn.Conv1d(in_channels=100, out_channels=200, kernel_size=(5,), padding=2))
        model.append(nn.ReLU())
        model.append(nn.BatchNorm1d(200))
        model.append(nn.Dropout(p=0.3))
        model.append(nn.Flatten())
        model.append(nn.Linear(200*seq_len,100))
        model.append(nn.ReLU())
        model.append(nn.Dropout(p=0.3))
        model.add_module("linear",nn.Linear(100,1))
        return model
    def forward(self, x):
        x= torch.permute(x,(0,2,1))
        return self.model(x)
    
class ResBlock(nn.Module):
    def __init__(self,dim):
        super(ResBlock,self).__init__()
        self.dim = dim
        model = nn.Sequential()
        model.append(nn.Conv1d(in_channels=self.dim, out_channels=self.dim, kernel_size=(5,), padding=2))
        model.append(nn.ReLU())
        model.append(nn.Conv1d(in_channels=self.dim, out_channels=self.dim, kernel_size=(5,), padding=2))
        model.append(nn.ReLU())
        self.model = model
 
    def forward(self, x):
        output = self.model(x)
        return x+0.3*output
    
class Discriminator(nn.Module):
    def __init__(self,n_sequences=10,seq_length=26):
        super(Discriminator,self).__init__()
        self.seq_len = seq_length
        self.n_sequences = n_sequences
        self.model = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=100, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Conv1d(in_channels=100, out_channels=200, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(200),
            nn.Dropout(p=0.3),
            nn.Flatten(),
            nn.Linear(200*118,100),
            nn.ReLU(),
            nn.Dropout(p=0.3))
        self.model.add_module("linear",nn.Linear(100,1))
        self.init()
    def init(self,fpath="/Users/john/data/models/sev_pl3_5.pt"):
        save_model = torch.load(fpath)
        modules = {}
        for name, module in save_model.model.named_modules():
            modules[name] = module
        for name,module in self.model.named_modules():
            if isinstance(module,(nn.Conv1d,nn.Linear,nn.BatchNorm1d)):
                module.weight.data.copy_(modules[name].weight.data)
                module.weight.requires_grad=False
                module.bias.data.copy_(modules[name].bias.data)
                module.bias.requires_grad=False
            if isinstance(module,nn.BatchNorm1d):
                module.running_mean.data.copy_(modules[name].running_mean.data)
                module.running_mean.requires_grad=False
                module.running_var.data.copy_(modules[name].running_var.data)
                module.running_var.requires_grad=False
    def forward(self,x):
        x = torch.permute(x,(0,2,1))
        return self.model(x)

class Generator(nn.Module):
    def __init__(self,input_dim, seq_len,dim=128,layers=5):
        super(Generator,self).__init__()
        self.seq_len = seq_len
        self.dim = dim
        self.linear = nn.Linear(input_dim,seq_len*dim)
        self.conv = nn.Sequential()
        for i in range(layers):
            self.conv.append(ResBlock(dim))
        self.conv.append(nn.Conv1d(in_channels=dim, out_channels=4, kernel_size=(1,)))
        self.conv.append(nn.BatchNorm1d(4))
#         self.rnn = nn.LSTM(dim,4,batch_first=True)
#         self.ln = nn.LayerNorm(4)
    def forward(self,x):
        x = self.linear(x).reshape(-1,self.dim,self.seq_len)
        x = self.conv(x).permute((0,2,1))
#         x = torch.relu(self.rnn(x)[0])
#         x = self.ln(x)
        return nn.functional.softmax(x,dim=2)
   

seq_len,dim=118,100
model = Discriminator(seq_length=seq_len)
print(summary(model, input_size=(16, seq_len,4)))

# model(torch.rand((16,seq_len, 4)))
model = Generator(32,seq_len)
print(summary(model, input_size=(16, 32)))
# model(torch.rand((16,32)))
# torch.save(model,"gan.pt")

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [16, 1]                   --
├─Sequential: 1-1                        [16, 1]                   --
│    └─Conv1d: 2-1                       [16, 100, 118]            (2,100)
│    └─ReLU: 2-2                         [16, 100, 118]            --
│    └─Conv1d: 2-3                       [16, 100, 118]            (50,100)
│    └─ReLU: 2-4                         [16, 100, 118]            --
│    └─Dropout: 2-5                      [16, 100, 118]            --
│    └─Conv1d: 2-6                       [16, 200, 118]            (100,200)
│    └─ReLU: 2-7                         [16, 200, 118]            --
│    └─BatchNorm1d: 2-8                  [16, 200, 118]            (400)
│    └─Dropout: 2-9                      [16, 200, 118]            --
│    └─Flatten: 2-10                     [16, 23600]               --
│    └─Linear: 2-11                      [16, 100]              

In [3]:
import tqdm

class MSELoss(nn.Module):
    def __init__(self,target=0):
        super(MSELoss, self).__init__()
        self.target = target

    def forward(self, x):
        # 计算均方误差
        return torch.mean((x-self.target)**2)

batch_size,seq_len,dim=32,118,512
generator = Generator(128,seq_len,dim)
discriminator =Discriminator(seq_len,dim)
criterion = MSELoss(target=1)
optimizer_g = optim.Adam(generator.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-06)
iterations = 10
epoch_loss=[]
for it in tqdm.tqdm(range(iterations)):
    generator.train()
    fake_seed = torch.rand(batch_size,128)
    fake_input = generator(fake_seed)
    outputs = discriminator(fake_input)
    loss = criterion(outputs)
    epoch_loss.append(loss.item())
    optimizer_g.zero_grad()
    loss.backward()
    optimizer_g.step()
    if (it+1) % 100 ==0:
        loss = np.average(epoch_loss)
        epoch_loss=[]
        print(f"Iterations:{it+1} Loss:{loss}")


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:48<00:00,  4.87s/it]


In [4]:
import Levenshtein
import pandas as pd

class OneHotEncoder(object):
    
    def __init__(self,seq_len):
        self.seq_len=seq_len
        self.nuc_d = {'a':[1,0,0,0],
             'c':[0,1,0,0],
             'g':[0,0,1,0],
             't':[0,0,0,1],
             'n':[0,0,0,0]}
    
    def __call__(self,seq):
        seq = seq[:self.seq_len].lower()
        return np.array([self.nuc_d[x] for x in seq])
    
encoder = OneHotEncoder(seq_len)
chars = "ACGT"
original_seq = 'atcccgggtgaggcatcccaccatcctcagtcacagagagacccaatctaccatcagcatcagccagtaaagattaagaaaaacttagggtgaaagaaatttcacctaacacggcgca'
original_seq=original_seq.upper()
fake_seed = torch.rand(64,128)
gen_seqs = generator(fake_seed)
# dist = torch.distributions.Categorical(gen_seqs)
# vectors = dist.sample()
vectors = torch.argmax(gen_seqs,dim=2)
outputs = discriminator(gen_seqs)
seqs,dists,preds = [], [], []
for seq,v in zip(vectors,outputs):
    s = "".join([chars[i] for i in seq])
    d = Levenshtein.distance(original_seq, s)
    if s not in seqs:
        seqs.append(s)
        dists.append(d)
        preds.append(v.detach().numpy()[0])
        
sdf = pd.DataFrame(data={"seq":seqs,"distance":dists,"values":preds})
x_input = torch.tensor(np.array([encoder(v) for v in sdf.seq.values]),dtype=torch.float)
sdf["pred"] = discriminator(x_input)
sdf

Unnamed: 0,seq,distance,values,pred
0,CTGTTTTTTTCTGGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,73,0.38215,0.639559
1,CTGTTTTTTTCTGGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,72,0.889458,1.34896
2,CTGTTTTTTTCTAGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,73,1.151381,0.998149
3,CTGTTTTTTTCTAGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,72,0.278828,1.101954
4,CTGTTTTTTTCTGGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,73,0.464135,1.224701
5,CTGTTTTTTTCTAGGCTAGCGTGTGCTAGCGTCCCACGCACTATTC...,73,0.406898,1.214737
6,CTGTTTTTTTCTGGGCTAGCGTGTGCTAGCGTCCCACGCACTATTC...,73,0.89659,0.586801
7,CTGTTTTTTTCTGGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,73,1.323221,0.683968
8,CTGTTTTTTTCTAGGCTAGCGTGTGCTAGCATCCCACGCACTATTC...,73,0.916669,0.683416
9,CTGTTTTTTTCTAGGCTAGCATGTGCTAGCATCCCACGCACTATTC...,71,0.896507,0.715848
