In [None]:
import torch 
from torch import nn, optim,utils
from torchinfo import summary
import numpy as np
import tqdm

class MSELoss(nn.Module):
    """Mean squared deviation of a tensor from a fixed target value.

    The target is supplied once at construction time, so ``forward`` takes
    only the tensor to be scored.
    """

    def __init__(self, target=0):
        super().__init__()
        self.target = target

    def forward(self, x):
        """Return the mean of ``(x - target) ** 2`` over all elements of ``x``."""
        # Compute the mean squared error against the stored target.
        deviation = x - self.target
        return torch.mean(deviation ** 2)

class OneHotEncoder(object):
    """Callable that turns a nucleotide string into a one-hot NumPy array.

    Characters are looked up case-insensitively in ``nuc_d``; any character
    without an entry raises ``KeyError``.
    """

    def __init__(self, seq_len):
        # Maximum number of leading characters that will be encoded.
        self.seq_len = seq_len
        # Column order is a, c, g, t; 'n' maps to an all-zero row.
        self.nuc_d = {
            'a': [1, 0, 0, 0],
            'c': [0, 1, 0, 0],
            'g': [0, 0, 1, 0],
            't': [0, 0, 0, 1],
            'n': [0, 0, 0, 0],
        }

    def __call__(self, seq):
        """Encode the first ``seq_len`` characters of ``seq``.

        Returns an array with one 4-element row per encoded character.
        """
        truncated = seq[:self.seq_len].lower()
        rows = []
        for base in truncated:
            rows.append(self.nuc_d[base])
        return np.array(rows)

class SampleNoGradient(torch.autograd.Function):
    """Draw a one-hot categorical sample and pass gradients straight through.

    ``forward`` treats ``x[..., 0]`` as category probabilities along its last
    axis, samples one category per position and returns the sample one-hot
    encoded. ``backward`` hands the incoming gradient back unchanged, placed
    in slot 0 of the trailing axis that ``forward`` indexed away.
    """

    @staticmethod
    def forward(ctx, x):
        # Remember the input shape so backward can return a gradient that
        # matches the input (forward drops the trailing axis).
        ctx.input_shape = x.shape
        probs = x[..., 0]
        dist = torch.distributions.Categorical(probs)
        # Pin num_classes: without it one_hot infers the width from the largest
        # sampled index, so the class axis shrinks whenever the last category
        # happens not to be drawn.
        pwm_sample = torch.nn.functional.one_hot(
            dist.sample(), num_classes=probs.shape[-1]
        )
        return pwm_sample

    @staticmethod
    def backward(ctx, g):
        # The gradient w.r.t. the input must have the input's shape; only
        # index 0 of the trailing axis took part in the forward pass.
        grad_input = g.new_zeros(ctx.input_shape)
        grad_input[..., 0] = g
        return grad_input

class CNN(nn.Module):
    """1-D convolutional regressor over one-hot encoded sequences.

    Args:
        seq_len: Length of the input sequences; it fixes the input size of the
            first fully connected layer (``200 * seq_len``).
    """

    def __init__(self, seq_len=44):
        super().__init__()
        self.seq_len = seq_len
        self.model = self.cnn_model()

    def cnn_model(self):
        """Build the conv stack and dense head as an ``nn.Sequential``.

        The first thirteen layers get the default names ``"0"``..``"12"`` and
        the output layer is registered as ``"linear"``.
        """
        model = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=100, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Conv1d(in_channels=100, out_channels=200, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(200),
            nn.Dropout(p=0.3),
            nn.Flatten(),
            # Fixed: the original read the bare name `seq_len`, which is only
            # defined if a global of that name happens to exist.
            nn.Linear(200 * self.seq_len, 100),
            nn.ReLU(),
            nn.Dropout(p=0.3),
        )
        model.add_module("linear", nn.Linear(100, 1))
        return model

    def forward(self, x):
        """Score a batch laid out as ``(batch, seq_len, 4)``; returns ``(batch, 1)``."""
        # Conv1d wants channels first: (batch, 4, seq_len).
        x = torch.permute(x, (0, 2, 1))
        return self.model(x)
    
class MyInit(object):
    """Initialiser meant to be passed to ``nn.Module.apply``.

    For every visited module it fills whichever of the attributes
    ``template``, ``mask`` and ``pwm`` the module has:

    * ``template`` - per-position logits derived from the template strings:
      ``'N'`` -> all zeros, ``'X'`` -> all ``-1``, any other nucleotide ->
      ``10`` for that nucleotide and ``-4`` for the other three.
    * ``mask`` - ``1`` at ``'N'`` positions and ``0`` everywhere else.
    * ``pwm`` - Xavier-uniform initialisation.

    Args:
        templates: Sequence of template strings, one per designed sequence.
            Each must provide at least ``seq_length`` characters.
        seq_length: Number of positions to initialise.
        p_init: Stored on the instance; not used by the current initialisation.
    """

    def __init__(self, templates, seq_length=26, p_init=0.5):
        self.templates = templates
        self.seq_length = seq_length
        self.p_init = p_init

    def __call__(self, module):
        # Normalise once: compare case-insensitively (a lowercase 'n' used to be
        # pinned to 'A') and ignore characters past seq_length (they used to
        # index out of bounds).
        templates = [t[:self.seq_length].upper() for t in self.templates]

        if hasattr(module, 'template'):
            encoder = OneHotEncoder(self.seq_length)
            # The encoder's lookup table has no 'x' entry. Rows at 'X' positions
            # are overwritten below anyway, so encode them as 'N' placeholders.
            onehot_templates = np.concatenate(
                [encoder(t.replace('X', 'N')).reshape((1, self.seq_length, 4, 1))
                 for t in templates],
                axis=0,
            ).astype(np.float32)
            for i, template in enumerate(templates):
                for j, base in enumerate(template):
                    if base == 'N':
                        continue
                    if base == 'X':
                        onehot_templates[i, j, :, :] = -1
                    else:
                        nt_ix = np.argmax(onehot_templates[i, j, :, 0])
                        onehot_templates[i, j, :, :] = -4
                        onehot_templates[i, j, nt_ix, :] = 10
            # Keep the dtype/device of the existing tensor instead of silently
            # replacing it with an integer CPU tensor.
            module.template.data = torch.tensor(
                onehot_templates,
                dtype=module.template.dtype,
                device=module.template.device,
            )

        if hasattr(module, 'mask'):
            onehot_masks = np.zeros(
                (len(templates), self.seq_length, 4, 1), dtype=np.float32
            )
            for i, template in enumerate(templates):
                for j, base in enumerate(template):
                    if base == 'N':
                        onehot_masks[i, j, :, :] = 1.0
            module.mask.data = torch.tensor(
                onehot_masks,
                dtype=module.mask.dtype,
                device=module.mask.device,
            )

        if hasattr(module, 'pwm'):
            nn.init.xavier_uniform_(module.pwm.data)
        

class DNAPWM(nn.Module):
    """Trainable position weight matrices scored by a pretrained CNN predictor.

    ``pwm`` holds trainable logits of shape ``(n_sequences, seq_length, 4, 1)``.
    ``template`` and ``mask`` have the same shape and are not trainable; the
    effective logits are ``pwm * mask + template``.

    Args:
        n_sequences: Number of sequences optimised in parallel.
        seq_length: Number of positions per sequence. It also sizes the
            predictor's first dense layer, so it must match the sequence
            length of the pretrained model loaded by ``init``.
    """

    def __init__(self, n_sequences=10, seq_length=26):
        super().__init__()
        self.n_sequences = n_sequences
        self.seq_length = seq_length
        self.template = nn.Parameter(torch.zeros((n_sequences, seq_length, 4, 1)), requires_grad=False)
        self.mask = nn.Parameter(torch.zeros((n_sequences, seq_length, 4, 1)), requires_grad=False)
        self.pwm = nn.Parameter(torch.randn((n_sequences, seq_length, 4, 1)), requires_grad=True)
        self.predictor = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=100, out_channels=100, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Conv1d(in_channels=100, out_channels=200, kernel_size=(5,), padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(200),
            nn.Dropout(p=0.3),
            nn.Flatten(),
            # Was hard-coded to 200*118, which only fits seq_length == 118.
            nn.Linear(200 * seq_length, 100),
            nn.ReLU(),
            nn.Dropout(p=0.3))
        self.predictor.add_module("linear", nn.Linear(100, 1))
        self.init()

    def init(self, fpath="/Users/john/data/models/sev_pl3_5.pt"):
        """Copy pretrained weights into ``predictor`` and freeze them.

        ``fpath`` must point to a pickled model object exposing a ``.model``
        attribute whose submodule names match those of ``self.predictor``.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects - only load
        # files you trust. Recent PyTorch releases may additionally require
        # weights_only=False to load a whole pickled model - confirm against the
        # installed version.
        # map_location keeps loading working on machines without the device the
        # checkpoint was saved from.
        save_model = torch.load(fpath, map_location="cpu")
        pretrained = dict(save_model.model.named_modules())
        for name, module in self.predictor.named_modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear, nn.BatchNorm1d)):
                continue
            source = pretrained[name]
            module.weight.data.copy_(source.weight.data)
            module.weight.requires_grad = False
            module.bias.data.copy_(source.bias.data)
            module.bias.requires_grad = False
            if isinstance(module, nn.BatchNorm1d):
                # Running statistics are buffers, so there is no grad flag to clear.
                module.running_mean.data.copy_(source.running_mean.data)
                module.running_var.data.copy_(source.running_var.data)

    def forward(self):
        """Return ``(pwm_logits, pwm, pwm_sample)``.

        ``pwm`` is the softmax of the logits over the nucleotide axis (dim 2).
        ``pwm_sample`` is a one-hot tensor with a trailing singleton axis: a
        random categorical draw in training mode, the argmax in eval mode.
        """
        pwm_logits = self.pwm * self.mask + self.template
        pwm = nn.functional.softmax(pwm_logits, dim=2)
        if self.training:
            pwm_sample = SampleNoGradient.apply(pwm)
        else:
            sample = torch.argmax(pwm[..., 0], dim=2)
            # Pin num_classes so the nucleotide axis keeps its full width even
            # when the last class is never the argmax.
            pwm_sample = torch.nn.functional.one_hot(sample, num_classes=pwm.shape[2])
        pwm_sample = torch.unsqueeze(pwm_sample, dim=3)
        return pwm_logits, pwm, pwm_sample

    def loss(self):
        """Run the predictor on the current PWMs and return its raw output.

        In training mode the soft ``pwm`` is fed to the predictor (the sampled
        one-hot is not used here); in eval mode the hard one-hot sample is used.
        """
        # NOTE(review): module.train() also puts `predictor` into training mode,
        # so dropout is active and BatchNorm updates its running statistics even
        # though the weights are frozen - confirm this is intended.
        pwm_logits, pwm, pwm_sample = self.forward()
        x = pwm if self.training else pwm_sample
        # (n_sequences, seq_length, 4) -> (n_sequences, 4, seq_length) for Conv1d.
        x = x[..., 0].permute((0, 2, 1)).to(torch.float)
        return self.predictor(x)

    
# --- Problem setup ---------------------------------------------------------
seq_length = 118
original_seq = 'atcccgggtgaggcatcccaccatcctcagtcacagagagacccaatctaccatcagcatcagccagtaaagattaagaaaaacttagggtgaaagaaatttcacctaacacggcgca'.upper()

# Keep both flanks fixed and leave a 24-position window (indices 72-95) free.
prefix = original_seq[:72]
suffix = original_seq[96:]
templates = [prefix + "N" * 24 + suffix] * 10
# Alternative: leave every position free with ["N" * seq_length] * 10.

model = DNAPWM(10, seq_length)
model.apply(MyInit(templates, seq_length=seq_length))
criterion = MSELoss(target=0)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-06)

# Predictor output for the freshly initialised PWMs, before any optimisation.
# (model.predictor.eval() is deliberately not called here.)
x = model.loss()
print(x)

# --- Optimisation loop -----------------------------------------------------
iterations = 10000
epoch_loss = []
for it in tqdm.trange(iterations):
    model.train()
    optimizer.zero_grad()
    outputs = model.loss()
    loss = criterion(outputs)
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())
    # Report the mean loss of each block of 1000 iterations, then start a new block.
    if (it + 1) % 1000 == 0:
        loss = np.mean(epoch_loss)
        epoch_loss = []
        print(f"Iterations:{it+1} Loss:{loss}")


tensor([[-1.7554],
        [-2.1592],
        [-1.3452],
        [-1.3751],
        [-1.3177],
        [-1.2945],
        [-1.7633],
        [-1.8098],
        [-1.7569],
        [-0.3740]], grad_fn=<AddmmBackward0>)


 10%|███████████▎                                                                                                     | 1005/10000 [00:30<04:50, 30.99it/s]

Iterations:1000 Loss:1.6879672874212266


 20%|██████████████████████▋                                                                                          | 2008/10000 [00:58<03:10, 42.02it/s]

Iterations:2000 Loss:0.3106649994617328


 30%|█████████████████████████████████▉                                                                               | 3007/10000 [01:28<02:54, 40.01it/s]

Iterations:3000 Loss:0.0429533090214245


 40%|█████████████████████████████████████████████▏                                                                   | 4004/10000 [01:55<04:29, 22.22it/s]

Iterations:4000 Loss:0.036083709297236054


 50%|████████████████████████████████████████████████████████▌                                                        | 5004/10000 [02:26<02:38, 31.45it/s]

Iterations:5000 Loss:0.03249288568343036


 60%|███████████████████████████████████████████████████████████████████▊                                             | 6006/10000 [02:52<01:38, 40.70it/s]

Iterations:6000 Loss:0.03009747034916654


 70%|███████████████████████████████████████████████████████████████████████████████▏                                 | 7007/10000 [03:23<01:29, 33.38it/s]

Iterations:7000 Loss:0.027860181177733465


 80%|██████████████████████████████████████████████████████████████████████████████████████████▍                      | 8003/10000 [03:57<01:01, 32.39it/s]

Iterations:8000 Loss:0.025588324351701885


 87%|██████████████████████████████████████████████████████████████████████████████████████████████████▎              | 8698/10000 [04:19<00:31, 41.16it/s]

In [None]:
import pandas as pd
import Levenshtein
# Evaluate deterministically. After the training loop the model is still in
# training mode, so without eval() the predictor would run with dropout active
# and BatchNorm would use (and update from) batch statistics.
model.eval()
chars = "ACGT"
with torch.no_grad():
    _, pwm, _ = model()
    # Most likely nucleotide index at every position of every sequence.
    seqs_x = torch.argmax(pwm[..., 0], dim=2)
    # Pin num_classes so the one-hot keeps 4 channels even if 'T' never wins.
    sample = torch.nn.functional.one_hot(seqs_x, num_classes=len(chars))
    sample = sample.permute((0, 2, 1)).to(torch.float)
    preds = model.predictor(sample).detach().cpu().numpy()[:, 0]

# Decode the index matrix into strings and compare with the starting sequence.
seqs = ["".join(chars[i] for i in seq) for seq in seqs_x.tolist()]
dists = [Levenshtein.distance(original_seq, s) for s in seqs]
pd.DataFrame(data={"seq": seqs, "distance": dists, "preds": preds})

In [None]:
# Score the original (unmodified) sequence with the predictor in eval mode.
original_seq = 'atcccgggtgaggcatcccaccatcctcagtcacagagagacccaatctaccatcagcatcagccagtaaagattaagaaaaacttagggtgaaagaaatttcacctaacacggcgca'.upper()
encoder = OneHotEncoder(118)
# Add a leading batch axis to the encoded sequence: (1, 118, 4).
x_test = torch.tensor(encoder(original_seq)[np.newaxis], dtype=torch.float)
model.predictor.eval()
# The predictor expects channels first, hence the permute to (1, 4, 118).
model.predictor(x_test.permute(0, 2, 1))

In [None]:
# Latest forward output of every hooked module, keyed by the module object.
visualisation = {}


def hook_fn(m, i, o):
    """Forward hook: store output ``o`` of module ``m``; the input ``i`` is ignored."""
    visualisation.update({m: o})

def get_all_layers(net):
    """Print each direct child's name and attach ``hook_fn`` to the leaf layers.

    ``nn.Sequential`` children are not hooked themselves; the function
    recurses into them and hooks their children instead. Every other child
    gets ``hook_fn`` registered as a forward hook.
    """
    for child_name, child in net._modules.items():
        print(child_name)
        if not isinstance(child, nn.Sequential):
            # Not a Sequential container: hook it directly.
            child.register_forward_hook(hook_fn)
            continue
        # Sequential container: descend instead of hooking the container.
        get_all_layers(child)

# get_all_layers(model)
# model.loss()
# visualisation


# List the name of every parameter registered on the model.
for name, _param in model.named_parameters():
    print(name)