# 1. Overview

# 2. Datasets

# 3. Implementation

## 3.1 Setup

### 3.1.1 Importing libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pickle
import json
import itertools
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm, trange

In [None]:
# Select the compute device once: the first CUDA GPU when one is available,
# otherwise the CPU. Later cells read this module-level `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")

### 3.1.2 Loading Data

In [None]:
# Mount the user's Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Change into the "My Drive" folder and extract data.zip into it.
# NOTE(review): the %cd path is relative, so this presumably expects the
# current working directory to be /content — confirm; re-running the cell
# after the directory change would likely fail to find drive/My Drive.
%cd drive/My\ Drive
!unzip data.zip

### 3.1.3 Reading Data

In [None]:
class TFBindingDataset(Dataset):
    """Skeleton map-style dataset; sample loading is not implemented yet.

    The constructor records ``dataset_path`` and starts with an empty
    ``indices`` list so that ``len(dataset)`` is well defined (0) instead of
    raising ``AttributeError`` before the loading logic is filled in.

    Attributes:
        dataset_path: The path passed to the constructor, stored unchanged.
        indices: List whose length defines the dataset size; empty until the
            loading logic populates it.
    """

    def __init__(self, dataset_path):
        super().__init__()
        self.dataset_path = dataset_path
        # TODO: read the data found at `dataset_path` and populate
        # `self.indices` with one entry per sample.
        self.indices = []

    def __getitem__(self, index):
        """Return the sample at ``index`` (placeholder: currently ``None``).

        Raises:
            IndexError: If ``index`` is outside the range of ``self.indices``.
        """
        size = len(self.indices)
        # Raising IndexError for out-of-range positions is what lets
        # sequence-style iteration (`for sample in dataset`) terminate.
        if not -size <= index < size:
            raise IndexError(
                "index %d is out of range for dataset of size %d" % (index, size)
            )
        # TODO: build and return the actual sample for `self.indices[index]`.
        return None

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.indices)``."""
        return len(self.indices)


## 3.2 Model

## 3.3 Plots & Evaluation Functions

## 3.4 Training & Test

In [None]:
def train(net, train_loader, val_loader, optimizer, criterion, epoch_num=100):
    """Train ``net`` for ``epoch_num`` epochs, validating after every epoch.

    Every batch yielded by the loaders must unpack into
    ``(xd_input, xt_input, labels)``. All three tensors are moved to the
    module-level ``device``; the two inputs are cast to ``long`` before being
    passed to ``net``, and both the outputs and the labels are cast to
    ``float`` before being passed to ``criterion``.

    Args:
        net: Model invoked as ``net(xd_input, xt_input)``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        optimizer: Optimizer; gradients are zeroed and a step is taken once
            per training batch.
        criterion: Loss callable ``criterion(outputs, labels)`` whose result
            supports ``.backward()`` and ``.item()``.
        epoch_num: Number of epochs to run.

    Returns:
        ``(train_log, val_log)``: two lists with one entry per epoch, holding
        the mean batch loss on the training and validation loaders.
    """
    train_log = []
    val_log = []

    for epoch in range(1, epoch_num + 1):
        # ---- training pass ----
        train_loss = []
        net.train()
        for xd_input, xt_input, labels in tqdm(
            train_loader, desc='Training epoch ' + str(epoch), leave=False
        ):
            xd_input = xd_input.to(device)
            xt_input = xt_input.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(xd_input.long(), xt_input.long())
            loss = criterion(outputs.float(), labels.float())
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        mean_train_loss = np.mean(train_loss)
        train_log.append(mean_train_loss)
        print('============ epoch %d =============' % epoch)
        # Each metric is printed on its own line; previously `end=''` glued
        # the two losses and the next epoch banner together.
        print('train loss: %.3f' % mean_train_loss, flush=True)

        # ---- validation pass (no gradient tracking) ----
        val_loss = []
        net.eval()
        with torch.no_grad():
            for xd_input, xt_input, labels in tqdm(
                val_loader, desc='Validation ', leave=False
            ):
                xd_input = xd_input.to(device)
                xt_input = xt_input.to(device)
                labels = labels.to(device)

                outputs = net(xd_input.long(), xt_input.long())
                loss = criterion(outputs.float(), labels.float())
                val_loss.append(loss.item())

        mean_val_loss = np.mean(val_loss)
        val_log.append(mean_val_loss)
        print('validation loss: %.3f' % mean_val_loss, flush=True)

    return train_log, val_log

In [None]:
def test(net, test_loader):
    """Run ``net`` over ``test_loader`` and collect labels and outputs.

    The model is switched to eval mode and gradient tracking is disabled.
    Every batch must unpack into ``(xd_input, xt_input, labels)``; the tensors
    are moved to the module-level ``device`` and the two inputs are cast to
    ``long`` before being passed to ``net``.

    Args:
        net: Model invoked as ``net(xd_input, xt_input)``.
        test_loader: Iterable of test batches.

    Returns:
        ``(y_true, y_pred)``: NumPy arrays built from the per-batch labels and
        model outputs, both cast to ``float`` and moved to the CPU first.
    """
    collected_labels, collected_outputs = [], []

    net.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Test ', leave=False):
            xd_input, xt_input, labels = batch
            xd_input = xd_input.to(device)
            xt_input = xt_input.to(device)
            labels = labels.to(device)

            outputs = net(xd_input.long(), xt_input.long())

            # extend() iterates the first axis, so each list holds one entry
            # per sample rather than one per batch.
            batch_labels = labels.float().detach().cpu().numpy()
            batch_outputs = outputs.float().detach().cpu().numpy()
            collected_labels.extend(batch_labels)
            collected_outputs.extend(batch_outputs)

    return np.array(collected_labels), np.array(collected_outputs)

# 4. Results