In [7]:
import pandas as pd

In [8]:
# (cell_type, sm_name, gene, d0_val, d1_val, d2_val) -> (de_val)

In [9]:
# DE labels
kaggle_train_de_df = pd.read_parquet('data/de_train.parquet')
genes = sorted(list(set(kaggle_train_de_df.columns.tolist()) - set(["cell_type", "sm_name", "sm_lincs_id", "SMILES", "control"])))
del kaggle_train_de_df

In [10]:
ddde = pd.read_parquet('data/ddde.parquet')

In [11]:
cell_types = sorted(ddde.cell_type.unique().tolist())
sm_names = sorted(ddde.sm_name.unique().tolist())

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

gene_count = len(genes)

class DDDEDataset(Dataset):
    def __init__(self, dfr):
        self.data = []
        self.labels = []
        cols = dfr.columns.tolist()
        n2i = {name: i for i, name in enumerate(cols)}
        for d in dfr.to_records(index=False):
            ci = cell_types.index(d[n2i["cell_type"]])
            si = sm_names.index(d[n2i["sm_name"]])
            exp = [d[n2i["%s_d2" % g]] for g in genes]
            self.data.append((ci, si, exp))
            self.labels.append([d[n2i["%s_de" % g]] for g in genes])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ci, si, exp = self.data[idx]
        return {
            'cell_type': ci,
            'sm_name': si,
            'exp': torch.tensor(exp, dtype=torch.float32),
            'label': torch.tensor(self.labels[idx], dtype=torch.float32)
        }
        

In [13]:

# Create an instance of the custom dataset
custom_dataset = DDDEDataset(ddde)


In [14]:
batch_size = 4
shuffle = True
data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

In [15]:
for batch in data_loader:
    print(batch["exp"].shape)
    break

torch.Size([4, 18211])


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim

class DECNN(nn.Module):
    def __init__(self, cell_type_size, sm_name_size, cell_type_dim=16, sm_name_dim=32, exp_dim=256, hidden_size=256):
        super(DECNN, self).__init__()

        self.cell_type_embedding = nn.Embedding(cell_type_size, cell_type_dim)
        self.sm_name_embedding = nn.Embedding(sm_name_size, sm_name_dim)
        
        self.fc_cont = nn.Linear(18211, exp_dim)
        self.fc1 = nn.Linear(cell_type_dim + sm_name_dim + exp_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 18211)

        self.dropout_fcc = nn.Dropout(p=0.5)
        self.dropout_fc1 = nn.Dropout(p=0.5) 

    def forward(self, cell_type, sm_name, exp):
        # Embedding categorical variable
        x_cell_type = self.cell_type_embedding(cell_type)
        x_sm_name = self.sm_name_embedding(sm_name)

        x_cont = torch.relu(self.fc_cont(exp))
        x_cont = self.dropout_fcc(x_cont)

        # Concatenate embeddings with continuous input
        x = torch.cat([x_cell_type, x_sm_name, x_cont], dim=1)

        # Feedforward layers
        x = torch.relu(self.fc1(x))
        x = self.dropout_fc1(x)
        x = self.fc2(x)
        
        return x

In [18]:
from torch.nn.parallel import DataParallel

net = DECNN(cell_type_size=len(cell_types), sm_name_size=len(sm_names))
net = DataParallel(net).cuda()

In [19]:
net

DataParallel(
  (module): DECNN(
    (cell_type_embedding): Embedding(6, 16)
    (sm_name_embedding): Embedding(146, 32)
    (fc_cont): Linear(in_features=18211, out_features=256, bias=True)
    (fc1): Linear(in_features=304, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=18211, bias=True)
    (dropout_fcc): Dropout(p=0.5, inplace=False)
    (dropout_fc1): Dropout(p=0.5, inplace=False)
  )
)

In [20]:
import tqdm as tq

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [37]:
# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    net.train()  # Set the model to training mode
    for batch in tq.tqdm(data_loader):
        optimizer.zero_grad()  # Zero the gradients
        cell_type = batch['cell_type'].cuda()
        sm_name = batch['sm_name'].cuda()
        exp = batch['exp'].cuda()
        targets = batch['label'].cuda()
        
        outputs = net(cell_type, sm_name, exp)
        loss = criterion(outputs, targets)
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}')

100%|██████████| 154/154 [00:03<00:00, 48.81it/s]

Epoch 1/1, Loss: 0.24229691922664642





In [38]:
torch.save(net.state_dict(), f"models/dcnn_epoch_{epoch + 1}.pt")

In [39]:
ddd_test = pd.read_parquet('data/ddd_test.parquet')

In [40]:
ddd_test

Unnamed: 0,cell_type,sm_name,CTGF,ARHGAP22,KLKB1,MEX3D,CXCL16,MAP4K4,CHCHD5,AL031123.2,...,TAS2R14,AC092164.1,AC108673.2,FAM129A,SLC43A3,FAS,SPATS2,FUNDC1,LOH12CR2,KIFC3
0,B cells,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-...,-4.949943e-06,-1.479604e-06,-4.760922e-08,3.758273e-07,-3.925180e-07,-6.653557e-05,-0.000027,7.799092e-07,...,-0.000003,-2.209602e-07,-5.017562e-07,0.000037,-3.545329e-06,-0.000013,-0.000023,-0.000021,-5.713630e-07,6.889568e-08
1,B cells,ABT-199 (GDC-0199),-4.949943e-06,-1.479604e-06,-4.760922e-08,-5.115271e-07,-2.069160e-06,-6.677311e-05,-0.000026,-1.074452e-07,...,-0.000003,-2.209602e-07,-5.017562e-07,0.000032,-3.593741e-06,-0.000010,-0.000022,-0.000021,-1.221373e-07,-4.004679e-07
2,B cells,ABT737,-4.949943e-06,-5.485636e-07,-4.760922e-08,-5.115271e-07,-5.383019e-07,-6.817715e-05,-0.000029,-1.074452e-07,...,-0.000003,-2.209602e-07,-3.623578e-08,0.000033,-3.078738e-06,-0.000015,-0.000022,-0.000021,-5.713630e-07,5.305723e-07
3,B cells,AMD-070 (hydrochloride),-4.949943e-06,-9.769299e-07,-4.760922e-08,-5.115271e-07,-2.448985e-06,-6.881478e-05,-0.000026,8.979018e-07,...,-0.000003,7.521344e-07,-5.017562e-07,0.000035,-1.185708e-06,-0.000011,-0.000023,-0.000021,-5.713630e-07,-4.004679e-07
4,B cells,AT 7867,-4.949943e-06,-1.479604e-06,-4.760922e-08,-5.115271e-07,-1.221097e-07,-7.042482e-05,-0.000020,3.464027e-07,...,-0.000003,2.328877e-07,-5.017562e-07,0.000038,1.664286e-07,-0.000018,-0.000023,-0.000020,-5.713630e-07,5.433109e-07
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,Myeloid cells,Vandetanib,-1.187022e-08,-1.775823e-05,-5.540405e-08,-3.373618e-06,-1.046163e-04,-1.896645e-06,-0.000046,-2.532030e-06,...,-0.000003,-5.076063e-07,-6.335904e-07,-0.000078,-6.721181e-05,-0.000011,-0.000013,-0.000019,-9.455799e-07,-1.662527e-05
251,Myeloid cells,Vanoxerine,-1.187022e-08,-1.737945e-05,-5.540405e-08,-3.373618e-06,-1.034851e-04,-3.399638e-06,-0.000043,-1.648393e-06,...,-0.000003,-5.076063e-07,-6.335904e-07,-0.000074,-6.827758e-05,-0.000007,-0.000012,-0.000019,-9.455799e-07,-1.662527e-05
252,Myeloid cells,Vardenafil,-1.187022e-08,-1.826309e-05,-5.540405e-08,-2.927031e-06,-1.037370e-04,-5.232093e-06,-0.000047,-1.638855e-06,...,-0.000003,-5.076063e-07,-6.335904e-07,-0.000084,-6.959435e-05,-0.000007,-0.000013,-0.000019,-9.455799e-07,-1.562148e-05
253,Myeloid cells,Vorinostat,-1.187022e-08,-1.733847e-05,-5.540405e-08,-3.373618e-06,-1.051561e-04,-3.522222e-06,-0.000048,-2.532030e-06,...,-0.000003,6.333443e-09,-6.335904e-07,-0.000088,-7.147978e-05,-0.000010,-0.000013,-0.000019,-9.455799e-07,-1.662527e-05


In [41]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

gene_count = len(genes)

class DDDTestDataset(Dataset):
    def __init__(self, dfr):
        self.data = []
        cols = dfr.columns.tolist()
        n2i = {name: i for i, name in enumerate(cols)}
        for d in dfr.to_records(index=False):
            ci = cell_types.index(d[n2i["cell_type"]])
            si = sm_names.index(d[n2i["sm_name"]])
            exp = [d[n2i[g]] for g in genes]
            self.data.append((ci, si, exp))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ci, si, exp = self.data[idx]
        return {
            'cell_type': ci,
            'sm_name': si,
            'exp': torch.tensor(exp, dtype=torch.float32)
        }


In [42]:
test_dataset = DDDTestDataset(ddd_test)

In [43]:
batch_size = 4
shuffle = False
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

In [44]:
net.eval()
with torch.no_grad(), open("dcnn_test.tsv", "w") as f:
    val_loss = 0.0
    for batch in test_loader:
        cell_type = batch['cell_type'].cuda()
        sm_name = batch['sm_name'].cuda()
        exp = batch['exp'].cuda()
        
        outputs = net(cell_type, sm_name, exp)

        for ci, si, de in zip(cell_type.cpu().tolist(), sm_name.cpu().tolist(), outputs.squeeze().cpu().tolist()):
            f.write("%s\n" % "\t".join([cell_types[ci], sm_names[si]] + [str(d) for d in de]))
        

In [45]:
id_map = pd.read_csv('data/id_map.csv', delimiter=',')

In [46]:
sample_submission = pd.read_csv('data/sample_submission.csv', delimiter=',')

In [47]:
# sample_submission.columns.tolist()

In [48]:
de_test = pd.read_csv('dcnn_test.tsv', delimiter='\t', names=["cell_type", "sm_name"] + genes)

In [49]:
de_test = pd.merge(de_test, id_map, on=["cell_type", "sm_name"], how="inner")

In [50]:
de_test.drop("cell_type", axis=1, inplace=True)
de_test.drop("sm_name", axis=1, inplace=True)

In [51]:
de_test.reset_index(inplace=True)

In [52]:
de_test

Unnamed: 0,index,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1,id
0,0,0.155849,-0.174587,0.103509,0.339105,1.523926,0.461737,-0.024070,0.030861,0.091644,...,0.115310,-0.217786,0.441188,-0.045563,-0.235614,0.264368,0.055558,0.051327,-0.084047,0
1,1,0.619236,-0.009944,0.823402,0.718217,3.846050,2.055936,-0.111335,0.340667,0.139396,...,0.393429,-0.209659,1.127342,0.480144,0.114973,0.314412,0.122322,0.165130,-0.074966,1
2,2,0.853685,0.409799,0.538681,0.596208,3.100556,3.139338,-0.122525,0.564935,-0.031322,...,0.288138,0.091714,1.140084,0.520861,0.394136,0.384511,0.124197,0.514623,-0.280113,2
3,3,0.437619,0.074133,0.373912,0.387558,2.760167,1.281298,0.108854,0.363444,0.194341,...,0.123616,-0.105324,0.777619,0.042021,-0.015463,0.270953,0.196976,-0.192218,-0.015806,3
4,4,0.027279,-0.229642,0.125491,0.325646,1.888218,0.586075,-0.058396,-0.063963,0.181543,...,-0.198765,-0.284037,0.216306,-0.430562,-0.526814,0.180181,0.142282,-0.084016,-0.106153,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,250,-0.134192,-0.449634,-1.776406,0.239664,2.679216,-0.344044,-0.297276,-0.301243,-0.109133,...,-0.635444,-0.841229,-0.908861,0.281659,-0.271353,-0.363987,-0.239620,-0.336232,-0.278235,250
251,251,0.033252,-0.275319,-0.221106,0.230196,2.326233,0.420175,-0.153622,-0.126400,-0.323152,...,-0.562089,-0.395020,-0.934260,0.292659,-0.206263,-0.307342,-0.184705,-0.677172,0.034708,251
252,252,-0.310791,-0.456158,-1.480004,-0.144863,1.649695,-0.598495,-0.173415,-0.441047,-0.158660,...,-0.675714,-0.794524,-1.207577,-0.005749,-0.465122,-0.465239,-0.089834,-0.466985,-0.087723,252
253,253,-0.712854,-0.750422,-7.810081,-0.578058,4.479853,-1.376245,0.308214,-1.438605,0.681041,...,-1.661792,-4.541638,-2.781394,-0.717032,-1.745970,-2.600702,-1.447923,-0.306733,-0.708952,253


In [53]:
de_test[sample_submission.columns.tolist()].to_csv('submission.csv', index=False)

In [54]:
!zip s.zip submission.csv

updating: submission.csv (deflated 54%)


In [55]:
de_test[sample_submission.columns.tolist()]

Unnamed: 0,id,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,0,0.155849,-0.174587,0.103509,0.339105,1.523926,0.461737,-0.024070,0.030861,0.091644,...,0.329050,0.115310,-0.217786,0.441188,-0.045563,-0.235614,0.264368,0.055558,0.051327,-0.084047
1,1,0.619236,-0.009944,0.823402,0.718217,3.846050,2.055936,-0.111335,0.340667,0.139396,...,-0.044588,0.393429,-0.209659,1.127342,0.480144,0.114973,0.314412,0.122322,0.165130,-0.074966
2,2,0.853685,0.409799,0.538681,0.596208,3.100556,3.139338,-0.122525,0.564935,-0.031322,...,-0.263440,0.288138,0.091714,1.140084,0.520861,0.394136,0.384511,0.124197,0.514623,-0.280113
3,3,0.437619,0.074133,0.373912,0.387558,2.760167,1.281298,0.108854,0.363444,0.194341,...,-0.158527,0.123616,-0.105324,0.777619,0.042021,-0.015463,0.270953,0.196976,-0.192218,-0.015806
4,4,0.027279,-0.229642,0.125491,0.325646,1.888218,0.586075,-0.058396,-0.063963,0.181543,...,-0.026008,-0.198765,-0.284037,0.216306,-0.430562,-0.526814,0.180181,0.142282,-0.084016,-0.106153
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,250,-0.134192,-0.449634,-1.776406,0.239664,2.679216,-0.344044,-0.297276,-0.301243,-0.109133,...,-0.017711,-0.635444,-0.841229,-0.908861,0.281659,-0.271353,-0.363987,-0.239620,-0.336232,-0.278235
251,251,0.033252,-0.275319,-0.221106,0.230196,2.326233,0.420175,-0.153622,-0.126400,-0.323152,...,-0.357032,-0.562089,-0.395020,-0.934260,0.292659,-0.206263,-0.307342,-0.184705,-0.677172,0.034708
252,252,-0.310791,-0.456158,-1.480004,-0.144863,1.649695,-0.598495,-0.173415,-0.441047,-0.158660,...,-0.076098,-0.675714,-0.794524,-1.207577,-0.005749,-0.465122,-0.465239,-0.089834,-0.466985,-0.087723
253,253,-0.712854,-0.750422,-7.810081,-0.578058,4.479853,-1.376245,0.308214,-1.438605,0.681041,...,-0.494259,-1.661792,-4.541638,-2.781394,-0.717032,-1.745970,-2.600702,-1.447923,-0.306733,-0.708952
