In [1]:
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import argparse
from tqdm import tqdm  # optional progress bar
import pandas as pd
import torch
from torch import nn
from torch import nn, optim
from torch.nn import functional as F
from collections import OrderedDict
from torch import Tensor
import pprint
import random

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Training settings read by the training loop and by the DataLoaders.
hyperparams = dict(
    num_epochs=25,
    batch_size=400,
    learning_rate=1e-4,
)

cuda


In [2]:
class HistoneDataset(Dataset):
    """Per-gene histone-mark bins joined with an encoded sequence for that gene.

    Each sample is one (cell type, gene) pair. Its ``x`` is the gene's
    ``[bins, 5]`` histone-mark matrix with the gene's ``[k, 5]`` encoded
    sequence rows appended along dim 0; its ``y`` is the expression value.
    """

    def __init__(self, input_file, seq_file):
        """
        :param input_file: path to an ``.npz`` archive holding one array per
            cell type, indexed as ``[genes, bins, columns]`` with column 0 the
            gene id, columns 1-5 the histone marks and column 6 the expression
            value
        :param seq_file: path to a ``.npy`` array whose rows are
            ``(gene_id, encoded_sequence)``
        """
        # NOTE(review): allow_pickle=True unpickles arbitrary objects, which can
        # execute code. Only load sequence files from a trusted source.
        seqs = np.load(seq_file, allow_pickle=True)

        self.id2seq = dict()
        for i in range(seqs.shape[0]):
            row = seqs[i]
            # Two trailing zeros are appended before reshaping into rows of 5.
            # Presumably the raw encoding is two values short of a multiple of 5
            # -- TODO confirm against the encoder. A copy is padded so the
            # loaded array itself is left untouched.
            padded = list(row[1]) + [0, 0]
            self.id2seq[int(row[0])] = torch.tensor(padded).view(-1, 5).float()

        # Original author's note on the archive layout:
        #   [50, 16000, 100, 7] = [cell_types, genes, bins, columns]
        #   columns = GeneID, H3K27me3, H3K36me3, H3K4me1, H3K4me3, H3K9me3,
        #             Expression Value (same for every bin of a gene)
        marks_per_cell = []       # one [genes, bins, 5] array per cell type
        expression_per_cell = []  # one [genes] array per cell type
        ids_per_cell = []         # one [genes] array per cell type
        types = []                # cell-type name repeated once per gene

        # Use the archive as a context manager so its file handle is released
        # as soon as every array has been read.
        with np.load(input_file) as npdata:
            for cell in npdata.files:
                cell_data = npdata[cell]
                # Gene id and expression are read from the first bin only.
                ids_per_cell.append(cell_data[:, 0, 0])
                marks_per_cell.append(cell_data[:, :, 1:6])
                expression_per_cell.append(cell_data[:, 0, 6])
                types.extend([cell] * cell_data.shape[0])

        # Flatten the cell-type axis: [cell_types * genes, ...]
        histone_marks = np.concatenate(marks_per_cell, axis=0)
        expression = np.concatenate(expression_per_cell, axis=0)

        self.id = np.concatenate(ids_per_cell, axis=0)
        self.type = np.asarray(types)
        self.x = [torch.tensor(marks) for marks in histone_marks]
        self.y = [torch.tensor(value) for value in expression]

    def __len__(self):
        """
        :return: the number of (cell type, gene) samples in the dataset
        """
        return len(self.y)

    def __getitem__(self, idx):
        """
        Fetch one sample.

        :param idx: the index for retrieval

        :return: dictionary with keys ``cell_type`` (cell-type name), ``id``
            (gene id), ``x`` (histone-mark bins followed by the gene's encoded
            sequence rows, concatenated along dim 0) and ``y`` (expression
            value)
        """
        gene_id = self.id[idx]
        item = {
            "cell_type": self.type[idx],
            "id": gene_id,
            "x": torch.cat((self.x[idx], self.id2seq[int(gene_id)]), dim=0),
            "y": self.y[idx],
        }
        return item

In [3]:
class DenseLayer(nn.Module):
    """One dense layer: BN-ReLU-1x1 conv bottleneck, BN-ReLU-3x3 conv, dropout.

    Takes either a single tensor or a list of tensors (concatenated on the
    channel axis) and produces ``growth_rate`` new feature maps.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()
        # The 1x1 bottleneck widens to bn_size * growth_rate channels.
        bottleneck_channels = bn_size * growth_rate

        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            num_input_features, bottleneck_channels,
            kernel_size=1, stride=1, bias=False,
        )
        self.norm2 = nn.BatchNorm2d(bottleneck_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            bottleneck_channels, growth_rate,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
        """Concatenate ``inputs`` along dim 1 and apply the 1x1 bottleneck."""
        stacked = torch.cat(inputs, 1)
        normalised = self.norm1(stacked)
        activated = self.relu1(normalised)
        return self.conv1(activated)

    def forward(self, input):
        """Return the new feature maps for a tensor or a list of tensors."""
        # A bare tensor is wrapped so both call styles share one code path.
        prev_features = [input] if isinstance(input, Tensor) else input

        squeezed = self.bn_function(prev_features)
        grown = self.conv2(self.relu2(self.norm2(squeezed)))
        return self.dropout(grown)


class DenseBlock(nn.ModuleDict):
    """A stack of ``DenseLayer``s in which each layer sees all earlier outputs."""

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super().__init__()
        for index in range(1, num_layers + 1):
            # Every earlier layer contributed growth_rate extra channels.
            in_channels = num_input_features + (index - 1) * growth_rate
            self.add_module(
                'denselayer%d' % index,
                DenseLayer(
                    in_channels,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                ),
            )

    def forward(self, init_features):
        """Run every layer on the running feature list; return all maps concatenated on dim 1."""
        collected = [init_features]
        for layer in self.values():
            collected.append(layer(collected))
        return torch.cat(collected, 1)


class Transition(nn.Sequential):
    """BN -> ReLU -> 1x1 conv -> average pool, placed between dense blocks."""

    def __init__(self, num_input_features, num_output_features):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module(
            'conv',
            nn.Conv2d(num_input_features, num_output_features,
                      kernel_size=1, stride=1, bias=False),
        )
        # Stride 1 on the last axis keeps the output from shrinking there:
        # only the first spatial axis is halved.
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=(2, 1)))



class DenseNet(nn.Module):
    """Densely connected conv net over single-channel 2-D inputs.

    A convolutional stem is followed by one ``DenseBlock`` per entry of
    ``block_config``, with a ``Transition`` between consecutive blocks that
    scales the channel count by ``theta``. Globally pooled features feed a
    linear head with ``num_classes`` outputs.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0.5, num_classes=1, theta=0.5):

        super().__init__()

        # Stem: strides are (2, 1), so only the first spatial axis is downsampled.
        stem = OrderedDict()
        stem['conv0'] = nn.Conv2d(1, num_init_features, kernel_size=7, stride=(2, 1),
                                  padding=3, bias=False)
        stem['norm0'] = nn.BatchNorm2d(num_init_features)
        stem['relu0'] = nn.ReLU(inplace=True)
        stem['pool0'] = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.features = nn.Sequential(stem)

        # Dense blocks, with a channel-reducing transition after all but the last.
        channels = num_init_features
        last_stage = len(block_config) - 1
        for stage, depth in enumerate(block_config):
            self.features.add_module(
                'denseblock%d' % (stage + 1),
                DenseBlock(
                    num_layers=depth,
                    num_input_features=channels,
                    bn_size=bn_size,
                    growth_rate=growth_rate,
                    drop_rate=drop_rate,
                ),
            )
            channels = channels + depth * growth_rate
            if stage == last_stage:
                continue
            reduced = int(channels * theta)
            self.features.add_module(
                'transition%d' % (stage + 1),
                Transition(num_input_features=channels, num_output_features=reduced),
            )
            channels = reduced

        # Closing batch norm over the final feature maps.
        self.features.add_module('norm5', nn.BatchNorm2d(channels))

        # Output head.
        self.classifier = nn.Linear(channels, num_classes)

        # Initialisation: Kaiming-normal conv weights, identity batch norm,
        # zero bias on the linear head (its weight keeps PyTorch's default).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Map ``x`` through the feature extractor, global average pool and linear head."""
        activated = F.relu(self.features(x), inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        return self.classifier(torch.flatten(pooled, 1))

In [4]:
def train(model, train_loader, validate_loader, location):
    """Fit ``model`` with Adam on an MSE objective.

    After every epoch the weights are written to
    ``location + 'model<epoch>.pt'``, the mean of the most recent (up to 100)
    batch losses is printed, and ``validate`` is run on ``validate_loader``.

    :param model: network to optimise in place
    :param train_loader: iterable of batches with ``'x'`` and ``'y'`` entries
    :param validate_loader: loader passed on to ``validate``
    :param location: string prefix for the checkpoint paths
    """
    print("starting train")

    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=hyperparams['learning_rate'])

    model = model.train()

    for epoch in range(hyperparams['num_epochs']):
        # validate() leaves the model in eval mode, so switch back each epoch.
        model = model.train()
        recent_losses = []
        for batch in tqdm(train_loader):
            # Add a channel axis at dim 1 before the conv stem.
            features = batch['x'].unsqueeze(1).to(device)
            targets = batch['y'].to(device)

            optimizer.zero_grad()
            predictions = model(features)

            loss = criterion(predictions.squeeze(1), targets)

            loss.backward()  # calculate gradients
            optimizer.step()  # update model weights

            # Newest first, capped at the 100 most recent batch losses.
            recent_losses = [loss.item()] + recent_losses[:99]

        print("saving model")
        checkpoint_path = location + 'model' + str(epoch) + '.pt'
        torch.save(model.state_dict(), checkpoint_path)
        print(epoch, "epoch loss:", np.mean(recent_losses))
        validate(model, validate_loader)


def validate(model, validate_loader):
    """Print the mean per-batch MSE of ``model`` over ``validate_loader``.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again.

    :param model: network to evaluate
    :param validate_loader: iterable of batches with ``'x'`` and ``'y'`` entries
    """
    print("starting validation")
    loss_fn = torch.nn.MSELoss(reduction='mean')

    model = model.eval()
    losses = []

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and lowers memory use.
    with torch.no_grad():
        for batch in tqdm(validate_loader):
            x = batch['x']
            y = batch['y']
            x = x.unsqueeze(1)  # add a channel axis at dim 1
            x = x.to(device)
            y = y.to(device)

            y_pred = model(x)

            loss = loss_fn(y_pred.squeeze(1), y)

            losses.append(loss.item())

    print("mean loss:", np.mean(losses))


def test(model, test_loader, location):
    """Predict every sample in ``test_loader`` and write a submission CSV.

    The CSV is written to ``location + 'submission.csv'`` with columns ``id``
    (``"<cell_type>_<gene_id>"``) and ``expression`` (the prediction as text).

    :param model: trained network; it is switched to eval mode
    :param test_loader: iterable of batches with ``'x'``, ``'cell_type'`` and
        ``'id'`` entries
    :param location: string prefix for the output path
    """
    print("starting test")
    model = model.eval()
    classification = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            x = batch['x']
            cell_type = batch['cell_type']
            gene_id = batch['id']  # not named `id`, which would shadow the builtin
            x = x.unsqueeze(1)  # add a channel axis at dim 1
            x = x.to(device)

            y_pred = model(x)
            for i in range(y_pred.size()[0]):
                row_id = cell_type[i].item() + "_" + str(int(gene_id[i].item()))
                classification.append((row_id, str(y_pred[i].item())))

    df = pd.DataFrame(classification, columns=['id', 'expression'])
    df.to_csv(location + 'submission.csv', index=False)

In [None]:
# Switches controlling which stages of the pipeline this cell runs.
train_model = True
test_model = False
load_model = False
save_model = True

# Arguments: growth_rate=4, block_config=(2, 3, 3, 2), num_init_features=16.
model = DenseNet(4, (2, 3, 3, 2), 16).to(device)
location = './'
test_file = location + 'data/eval.npz'
train_file = location + 'data/train.npz'
seq_file = location + 'data/encoded.npy'

train_dataset = None
validate_dataset = None
test_dataset = None

# Give every loader a defined value, whichever switches are enabled.
train_loader = None
validate_loader = None
test_loader = None

print("gathering train data")
if train_model:
    dataset = HistoneDataset(train_file, seq_file)

    # 90% of the samples train the model; the remainder is held out.
    split_amount = int(len(dataset) * 0.9)

    train_dataset, validate_dataset = random_split(
        dataset, (split_amount, len(dataset) - split_amount))

    train_loader = DataLoader(
        train_dataset, batch_size=hyperparams['batch_size'], shuffle=True
    )
    # Evaluation gains nothing from shuffling; a fixed order keeps the
    # batches identical from one validation pass to the next.
    validate_loader = DataLoader(
        validate_dataset, batch_size=hyperparams['batch_size'], shuffle=False
    )

print("gathering test data", device)
if test_model:
    test_dataset = HistoneDataset(test_file, seq_file)
    test_loader = DataLoader(test_dataset, batch_size=hyperparams['batch_size'])

if load_model:
    print("loading saved model...")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(location + 'model.pt', map_location=device))
if train_model:
    print("running training loop...")
    train(model, train_loader, validate_loader, location)
    validate(model, validate_loader)
if save_model:
    print("saving model...")
    torch.save(model.state_dict(), location + 'model.pt')
if test_model:
    print("running testing loop...")
    test(model, test_loader, location)
    

gathering train data


  0%|                                                                                         | 0/1800 [00:00<?, ?it/s]

gathering test data cuda
running training loop...
starting train


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:39<00:00, 11.32it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:06, 32.09it/s]

saving model
0 epoch loss: 3.3236157298088074
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.00it/s]
  0%|                                                                                 | 1/1800 [00:00<05:13,  5.73it/s]

mean loss: 3.3372390055656433


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:40<00:00, 11.19it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 37.14it/s]

saving model
1 epoch loss: 3.2935297012329103
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 39.04it/s]
  0%|                                                                                         | 0/1800 [00:00<?, ?it/s]

mean loss: 3.304830753803253


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:39<00:00, 11.25it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 38.21it/s]

saving model
2 epoch loss: 3.286192448139191
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.69it/s]
  0%|                                                                                 | 1/1800 [00:00<04:54,  6.11it/s]

mean loss: 3.3054303514957426


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:38<00:00, 11.38it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 35.18it/s]

saving model
3 epoch loss: 3.3425286197662354
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 39.09it/s]
  0%|                                                                                 | 1/1800 [00:00<04:23,  6.82it/s]

mean loss: 3.2940007722377778


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:39<00:00, 11.31it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 35.49it/s]

saving model
4 epoch loss: 3.2971162486076353
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.73it/s]
  0%|                                                                                 | 1/1800 [00:00<04:50,  6.19it/s]

mean loss: 3.287163416147232


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:40<00:00, 11.23it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 35.49it/s]

saving model
5 epoch loss: 3.249870662689209
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.71it/s]
  0%|                                                                                 | 1/1800 [00:00<04:52,  6.15it/s]

mean loss: 3.282512205839157


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:38<00:00, 11.33it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 37.14it/s]

saving model
6 epoch loss: 3.28410178899765
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.16it/s]
  0%|                                                                                 | 1/1800 [00:00<04:48,  6.23it/s]

mean loss: 3.2658091259002684


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:39<00:00, 11.29it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:06, 32.34it/s]

saving model
7 epoch loss: 3.2826910424232483
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.20it/s]
  0%|                                                                                 | 1/1800 [00:00<04:48,  6.23it/s]

mean loss: 3.264050382375717


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:40<00:00, 11.24it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:05, 33.99it/s]

saving model
8 epoch loss: 3.2231621241569517
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.43it/s]
  0%|                                                                                 | 1/1800 [00:00<04:59,  6.00it/s]

mean loss: 3.2536452066898347


100%|██████████████████████████████████████████████████████████████████████████████| 1800/1800 [02:40<00:00, 11.18it/s]
  2%|█▋                                                                                | 4/200 [00:00<00:06, 31.83it/s]

saving model
9 epoch loss: 3.2431330418586732
starting validation


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 38.18it/s]
  0%|                                                                                 | 1/1800 [00:00<04:27,  6.73it/s]

mean loss: 3.247651263475418


 97%|███████████████████████████████████████████████████████████████████████████▉  | 1751/1800 [02:34<00:04, 11.49it/s]