# Implementing DCA in pytorch

Source CC function in [seqmodel](https://github.com/sokrypton/seqmodels/blob/master/seqmodels.ipynb)
Reimplemeting tensorflow code in pytorch

In [1]:
import numpy as np
import torch

import read_config
from dataloader import MSADataset, OneHotTransform

In [2]:
config = read_config.Config("../config2d.yaml")
dataset = MSADataset(config.aligned_msa_fullpath, transform=OneHotTransform(21, flatten=False))

N = len(dataset)

protein_seq, weight = dataset[0]
ncat = 21
L = protein_seq.shape[0]
ncol = L

N, L

(14441, 559)

In [3]:
protein_seq.shape

torch.Size([559, 21])

In [4]:
ncol, ncat

(559, 21)

In [5]:
batch_size = config.batch_size
#batch_size = len(dataset) ## to load in the whole dataset
msa = torch.utils.data.DataLoader(dataset, batch_size)

for _, msa_data in enumerate(msa):
    data  = msa_data[0]
    seq_weights = msa_data[1]
    break
    
print(f"Data.shape = {data.shape}")
print(f"Weights.shape = {seq_weights.shape}")

Data.shape = torch.Size([128, 559, 21])
Weights.shape = torch.Size([128])


In [6]:
device = config.device

In [7]:
bias = torch.zeros((ncol,ncat), dtype=torch.float, requires_grad=True, device=device)
w = torch.zeros((ncol, ncat, ncol, ncat), dtype=torch.float, requires_grad=True, device=device)

optimizer = torch.optim.SGD([bias, w], lr=config.learning_rate)

# we do not want weights between the various nodes in a given position. 
# i.e. weights between nodes (i, a) and (j, b) only exist if i not = j
# so set these weights to zero
w_eye = w * torch.reshape(1 - torch.eye(ncol), (ncol,1,ncol, 1))
# symmetrize w so that the weight between (i,a) and (j, b) is the
# same as the weight between (j, b) and (i, a)
weights = w_eye + w_eye.permute(2,3,0,1)

In [8]:
data = data.to(device)
seq_weights = seq_weights.to(device)

In [9]:
data.shape

torch.Size([128, 559, 21])

In [10]:
weights.shape

torch.Size([559, 21, 559, 21])

In [11]:
softmax_func = torch.nn.Softmax(-1)
loss_func = torch.nn.CrossEntropyLoss()

In [12]:
data_logit = torch.tensordot(data, weights, 2) + bias

data_logit.shape

torch.Size([128, 559, 21])

In [13]:
data_pred = softmax_func(data_logit)

data_pred.shape

torch.Size([128, 559, 21])

In [14]:
data.shape

torch.Size([128, 559, 21])

In [15]:
loss =  loss_func(data.permute(0,2,1),data_logit.permute(0,2,1))

ValueError: Expected target size (128, 559), got torch.Size([128, 21, 559])