In [1]:
%matplotlib inline

from gensim import models
import numpy as np
import matplotlib.pyplot as plt
import text_utils as tu
import multiprocessing
import os

import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

import torchtext
import torchtext.vocab as vocab

from tensorboardX import SummaryWriter

# Architecture

In [2]:
class EncoderGRU(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, batch_size, embedding_matrix=None, n_layers=1, seq_len=500, bidirectional=False):
        super(EncoderGRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.seq_len = seq_len
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.gru = nn.GRU(input_size=self.input_dim, hidden_size=self.hidden_dim, num_layers=n_layers, bidirectional=self.bidirectional)
        
        b = 2 if self.bidirectional else 1
        self.init_hidden_state = nn.Parameter(torch.FloatTensor(self.n_layers * b, self.batch_size, self.hidden_dim).normal_())

        
    def forward(self, x, hidden):
        output, hidden = self.gru(x, hidden)
        return output, hidden
    
    def init_hidden(self):
        return Variable(self.init_hidden_state)

In [3]:
class DecoderGRU(nn.Module):
    
    def __init__(self, input_dim, output_dim, hidden_dim, batch_size, embedding_matrix=None, n_layers=1, seq_len=500, bidirectional=False):
        super(DecoderGRU, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.bidirectional = bidirectional
        self.gru = nn.GRU(input_size=self.input_dim, hidden_size=self.hidden_dim, num_layers=n_layers)
        self.out = nn.Linear(self.hidden_dim, self.output_dim)
        self.softmax = nn.LogSoftmax(dim=1)
        
        b = 2 if self.bidirectional else 1
        self.init_hidden_state = nn.Parameter(torch.FloatTensor(self.n_layers * b, self.batch_size, self.hidden_dim).normal_())

    
    def forward(self, x, hidden):
        output, hidden = self.gru(x, hidden)
        output = self.out(output)
        output = self.softmax(output)
        return output, hidden
    
    def init_hidden(self, bs):
        return Variable(self.init_hidden_state)

# Embeddings

In [4]:
raw_reports = np.load('/home/rohanmirchandani/maxwell-pt-test/points.npy')
dirty_reports = [report['body'] for report in raw_reports]
clean_reports, _ = tu.clean_report(dirty_reports, clean=1) # first pass removes \n's and weird characters
tokenised_reports, report_vocab = tu.clean_report(clean_reports, clean=2) # second pass tokenises and builds vocab
vocab, raw_embeddings = tu.load_glove('/home/rohanmirchandani/glove/glove.6B.50d.w2vformat.txt', report_vocab, 50)
vocab['<SOS>'] = raw_embeddings.shape[0]
raw_embeddings = np.vstack((raw_embeddings, np.zeros((1, 50))))
vocab['<EOS>'] = raw_embeddings.shape[0]
raw_embeddings = np.vstack((raw_embeddings, np.ones((1, 50))))
vocab['<UNK>'] = raw_embeddings.shape[0]
raw_embeddings = np.vstack((raw_embeddings, -np.ones((1, 50))))

loaded GloVe


In [5]:
reverse_vocab = {v:k for k, v in vocab.items()}

# Training

In [6]:
exp_name = "seq2seq_unidirectional_learnable_init"
exp_path = "runs/seq2seq/gru/"
tb_path = os.path.join(exp_path, exp_name)
writer = SummaryWriter(log_dir=tb_path)

In [7]:
epochs = 5
criterion = nn.NLLLoss()

input_dim = 50
output_dim = len(vocab)
hidden_dim = 512
output_seq_len = 300
n_layers = 3
batch_size = 32

glove = torch.LongTensor(raw_embeddings)
embeddings = nn.Embedding.from_pretrained(glove, freeze=True)

eye = torch.eye(len(vocab))
one_hots = nn.Embedding.from_pretrained(eye, freeze=True)

E = EncoderGRU(input_dim, hidden_dim, batch_size, n_layers=n_layers, bidirectional=False).cuda()
D = DecoderGRU(input_dim, output_dim, hidden_dim, batch_size, n_layers=n_layers, bidirectional=False).cuda()
encoder_optim = optim.Adam(E.parameters())
decoder_optim = optim.Adam(D.parameters())

dataloader = tu.create_dataloader('/home/rohanmirchandani/maxwell-pt-test/examples/', batch_size=batch_size)

In [None]:
def train(inputs, targets, encoder, decoder, encoder_optim, decoder_optim, criterion, max_length, batch_size):
    
    e_hidden = encoder.init_hidden().cuda()   
#     e_hidden = Variable(torch.zeros(2, batch_size, 1024)) # for multi gpu training
    
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()
    
    z_output, e_hidden = encoder(inputs, e_hidden)
    
    d_hidden = e_hidden
    d_input = Variable(torch.zeros(1, batch_size, 50)).cuda() # <SOS> token embedding
    output_seq = list()
    loss = 0.0
    seq_length = 0
    for i in range(max_length):
        d_output, d_hidden = decoder(d_input, d_hidden) # compute next hidden state and current output token
        topv, topi = d_output.data.topk(1)
        ni = topi[0]
        idx = ni.data.cpu().numpy()[0][0]
        token = reverse_vocab[idx]
        output_seq.append(token)
        
        try:
            loss += criterion(d_output.squeeze(0), targets[i])
        except IndexError:
            print('This sample is cooked.')
            print(' '.join(output_seq))
            
        if token == '<EOS>':
            break
            
        d_input = Variable(embeddings(ni.transpose(0, 1).cpu()).float()).cuda() # <SOS> token embedding
        seq_length += 1
    
    loss.backward()
    encoder_optim.step()
    decoder_optim.step()
    
    return (loss / seq_length), output_seq

In [None]:
SOS = vocab['<SOS>']
EOS = vocab['<EOS>']
UNK = vocab['<UNK>']

print('Starting training!')
max_steps = 0
for epoch in range(epochs):
    print('EPOCH {}'.format(epoch))
    loss = 0.0
    for i, (tokens, idxs) in enumerate(dataloader):

        idxs = idxs.transpose(0, 1).long().view(-1, batch_size)
        targets = Variable(idxs).cuda()
        vectors = embeddings(idxs)
        e_input = Variable(vectors.float()).cuda()
        
        batch_loss, output_seq = train(inputs=e_input, 
                                       targets=targets, 
                                       encoder=E, 
                                       decoder=D, 
                                       encoder_optim=encoder_optim, 
                                       decoder_optim=decoder_optim, 
                                       criterion=criterion, 
                                       max_length=output_seq_len, 
                                       batch_size=batch_size)
        
        writer.add_scalar('seq2seq/loss', batch_loss.data.item(), (max_steps * epoch) + i)

        if i % 50 == 0:
            print(batch_loss)
            print(' '.join(output_seq))
    
    max_steps = i

Starting training!
EPOCH 0

 3.4657
[torch.cuda.FloatTensor of size () (GPU 0)]

orbit anxious esther esther vern subacute 430 designated designated designated snuff snuff snuff daynes daynes mcnamara mcnamara nor ost effort pacing pacing osmo osmo newnham newnham overestimation peculiar peculiar biggest 260 rinsing hre rinsing query femoris allergic joining joining 761 roper roper tail chetty missile missile vanderhorst beaded 1019 rest rest rest 2800 2800 2800 2800 passing strength tough depletion nerve fax spironolactone gym immunity 1317 1996 distorted thomason 74 dobbin luti luti gook gook graeme graeme infraction 1434 plate plate plate fuhrmann electrocardiogram give give 094 orlando orlando orlando sutures laster laster laster mainwaring lsc lsc lsc tropical bauer automated automated pearson cochlear cochlear cochlear cochlear cochlear sion rosenberger tanah tanah tanah ovaries ins ralston tenterfield tenterfield carotene carotene 482 ict ict ict saal saal infused aat aat philip


 3.0408
[torch.cuda.FloatTensor of size () (GPU 0)]

simson disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson simson mdna mdna mdna 78 78 78 78 78 78 78 78 78 78 78 78 78 78 78 prescribed 78 prescribed 78 prescribed prescribed prescribed prescribed prescribed disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance disturbance lumber nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril nostril lumber lumber 9300 9300 9300 9300 93


 2.7775
[torch.cuda.FloatTensor of size () (GPU 0)]

kenmore kenmore kenmore elms elms elms elms elms 106mm 106mm 106mm ventricles owing owing workstation workstation workstation workstation workstation workstation workstation workstation workstation workstation lu lu lu lu lu lu lu lu lu lu lu lu lu lu salkeld salkeld salkeld salkeld salkeld salkeld salkeld salkeld salkeld 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm 45mm thru runs runs runs runs runs runs runs runs mediated mediated mediated mediated mediated mediated mediated splenectomy splenectomy splenectomy splenectomy splenectomy splenectomy splenectomy splenectomy splenectomy splenectomy mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated keratinocytes s12 s12 s12 mucosal mucosal mucosal mucosal mucosal mucosal mucosal mucosal mucosal mucos


 2.7440
[torch.cuda.FloatTensor of size () (GPU 0)]

palpitations palpitations interspace interspace interspace palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace interspace dt dt dt dt dt dt dt dt dt greer greer greer greer greer greer greer greer greer greer greer greer greer bifid bifid bifid bifid bifid bifid bifid tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 1178 r5 r5 r5 r5 r5 r5 r5 r5 r5 r5 mervyn mervyn mervyn mervyn mervyn mervyn mervyn mervyn mervyn


 2.6456
[torch.cuda.FloatTensor of size () (GPU 0)]

mdna mdna kenmore kenmore kenmore kenmore kenmore kenmore elms elms elms elms elms elms elms elms elms elms elms elms elms elms elms elms elms condon 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 2219 78 operated operated operated operated operated operated operated operated operated operated operated operated repeating repeating repeating repeating repeating repeating repeating repeating repeating repeating repeating repeating mdna repeating repeating mdna craig craig mdna craig mdna craig mdna craig thru thru thru thru thru thru microvascular microvascular microvascular microvascular microvascular 1601 1601 mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated mediated lumber 9300 lumber 9300 lumber 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 9300 930


 2.4106
[torch.cuda.FloatTensor of size () (GPU 0)]

436 436 436 436 436 436 436 436 436 436 436 436 436 coli coli coli coli coli coli coli kenneth kenneth coli coli coli coli remarkably remarkably x8 x8 x8 x8 same craig 709 709 709 709 709 mossman mossman mossman mossman mossman simple simple simple dornbusch dornbusch dornbusch dornbusch dornbusch az argyris argyris argyris argyris 240 pocket pocket tear tear tear tear tear malabsorption hif sliding sliding sliding sliding sliding sliding hyperactive hyperactive hyperactive hyperactive kipling kipling scholtz scholtz bowen bowen bowen morrow morrow arteritis arteritis upm upm 1083 1083 1083 1083 steve steve steve irregularly arrest arrest arrest arrest racemase racemase racemase tp donna inglewood inglewood inglewood inglewood inglewood nh5 regular regular swanson swanson swanson swanson cartilage cartilage cartilage cartilage oleary authorial authorial authorial authorial authorial 1810 lamellae lamellae lamellae lamellae awaz awaz


 2.5930
[torch.cuda.FloatTensor of size () (GPU 0)]

886 steel steel steel steel steel steel steel kenmore kenmore worley worley worley worley worley worley worley worley worley worley worley worley groen groen groen groen groen groen groen groen yields yields yields yields yields yields yields yields yields yields yields yields yields yields 886 886 886 886 886 886 886 886 mucus steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel steel unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised unauthorised gw gw gw gw gw gw gw gw gw gw gw managed managed managed managed managed managed managed managed managed managed managed managed managed intracerebral intracerebral intracerebral intracerebral intracerebral intracerebral mucosal mucosal mucosal mucosal steadily steadily steadily steadily steadily steadily interrupt interrupt interrupt int


 2.5978
[torch.cuda.FloatTensor of size () (GPU 0)]

farm farm loch creat pretreatment creat farm kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron kedron intruding intruding intruding intruding intruding intruding intruding intruding intruding intruding intruding intruding pelvises intruding intruding intruding intruding intruding intruding intruding intruding pelvises intruding pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises pelvises mcdowell mcdowell mcdowell mcdowell mcdowell mcdowell mcdowell abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt abrupt advisable advisable advisable advisable advisable advisable advisable advisable advisable advisable investigative investigative investi


 2.5290
[torch.cuda.FloatTensor of size () (GPU 0)]

sandgate thromboembolism 898 nother 898 898 898 898 thromboembolism thromboembolism thromboembolism thromboembolism thromboembolism merely merely merely nother merely nother nother nother nother nother nother nother nother sandgate sandgate nother nother nother nother nother nother nother nother nother nother nother nother nother nother nother nother inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalational inhalati

This sample is cooked.
cheyne kenmore fain cheyne formulation kenmore kenmore kenmore kenmore kenmore marissa marissa willem willem kiran anthracycline anthracycline anthracycline anthracycline anthracycline cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne dimer dimer dimer dimer dimer dimer mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid mylohyoid deal deal deal 45mm 45mm 45mm 45mm 45mm 45mm cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne cheyne ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic ionic keratinocytes keratinocytes keratinocytes keratinocytes keratinocytes wentworth wentworth wentworth wentworth wentworth wentworth wentworth wentworth wentworth suture suture suture suture suture suture suture puffy puffy puffy puff

EPOCH 1

 2.7952
[torch.cuda.FloatTensor of size () (GPU 0)]

palpitations palpitations palpitations palpitations palpitations nausea nausea nausea palpitations palpitations palpitations palpitations palpitations palpitations nausea nausea nausea palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations palpitations proscar proscar proscar proscar proscar proscar proscar proscar proscar proscar nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen nguyen greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer greer shape tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristania tristan


 2.7060
[torch.cuda.FloatTensor of size () (GPU 0)]

sheridan tristania r5 sheridan sheridan sheridan sheridan sheridan sheridan inverted inverted inverted 366 366 366 clews elongate elongate elongate elongate elongate elongate elongate elongate elongate elongate walton walton walton walton walton walton walton walton thrombolysis thrombolysis thrombolysis thrombolysis elongate elongate elongate elongate elongate elongate elongate elongate elongate centro centro centro centro nostril chan chan chan chan chan chan chan puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy sacroiliac sacroiliac sacroiliac sacroiliac sacroiliac puffy sacroiliac puffy sacroiliac puffy sacroiliac puffy community community community community community community community community community community community bowel bowel bowel makena makena makena makena makena makena makena makena makena makena makena makena makena makena makena makena mak


 2.1097
[torch.cuda.FloatTensor of size () (GPU 0)]

436 436 482 482 couple couple couple fernando fernando fernando lane lane lane fernando fernando ovary ovary ovary thiery thiery radiating coli coli coli coli coli coli coli coli coli x8 x8 same same 709 709 709 709 mossman mossman mossman mossman simple simple dornbusch dornbusch dornbusch dornbusch az az az argyris argyris argyris 240 240 pocket pocket tear tear tear malabsorption malabsorption hif hif hif sliding sliding hyperactive hyperactive hyperactive hyperactive hyperactive hyperactive kipling scholtz scholtz scholtz bowen morrow arteritis arteritis arteritis upm upm upm upm upm 1083 1083 steve steve steve irregularly irregularly arrest arrest indentation danaher danaher cavernous cavernous cavernous donna inglewood inglewood bmm bmm displaying nh5 nh5 regular regular regular swanson swanson cartilage cartilage cartilage cartilage cartilage cartilage oleary authorial authorial authorial authorial authorial authorial authori


 2.5062
[torch.cuda.FloatTensor of size () (GPU 0)]

kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul bribie bribie bribie bribie bribie kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul kiul something something something something kiul kiul kiul kiul kiul kiul hsu hsu hsu hsu 709 709 709 709 mossman mossman mossman mossman mossman mossman mossman mossman dornbusch dornbusch dornbusch dornbusch dornbusch dornbusch dornbusch az argyris argyris argyris argyris argyris argyris az argyris argyris argyris argyris future future future future sliding sliding awaz hyperactive hyperactive hyperactive kipling kipling kipling kipling scholtz scholtz scholtz scholtz scholtz bowen bowen bowen bowen arteritis arteritis arteritis upm upm upm upm upm upm upm steve steve steve steve steve steve arrest arrest arrest arrest arrest arrest danaher danaher cavernous cavernous cavernous cavernous cavernous cavernous cavernous affect affect af


 2.0218
[torch.cuda.FloatTensor of size () (GPU 0)]

1053 1053 1053 n78 n78 n78 final final final opinion 2031 2031 2031 2031 2031 2031 2031 2031 2031 2031 2031 709 709 709 709 709 709 mossman mossman mossman mossman simple simple simple dornbusch dornbusch dornbusch dornbusch dornbusch az az az argyris argyris 240 240 pocket pocket pocket 847 tear tear future malabsorption hif pip pip sliding sliding hyperactive hyperactive hyperactive hyperactive kipling kipling kipling cavernosa cavernosa scholtz bowen bowen morrow morrow arteritis arteritis upm upm upm 1083 steve steve steve irregularly irregularly irregularly indentation indentation indentation danaher danaher cavernous cavernous cavernous donna inglewood inglewood inglewood displaying displaying displaying regular regular dorchester dorchester swanson rt cartilage cartilage cartilage cartilage bertram bertram authorial authorial authorial authorial authorial authorial encased encased encased encased encased encased encased encas


 2.4408
[torch.cuda.FloatTensor of size () (GPU 0)]

sheridan clews clews floating floating dt cottee sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan sheridan harts harts harts harts arumugam arumugam arumugam arumugam arumugam arumugam arumugam arumugam arumugam elongate elongate elongate elongate carroll carroll carroll carroll centro conclusively conclusively conclusively conclusively conclusively conclusively conclusively kersley kersley kersley kersley kersley kersley polypharmacy polypharmacy polypharmacy polypharmacy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy puffy bpm 388 388 388 388 388 388 388 ethmoid ethmoid ethmoid ethmoid ethmoid 388 388 388 388 388 388 388 388 388 388 388 388 388 388 388 bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm bpm carisbrook c