In [1]:
from preprocess import read_json, create_caption_dict, create_video_objects, clean_caption, gather_text, get_transforms
from preprocess import Video, VideoDataset, create_vocab, get_caption_dict
from base import EncoderBase, DecoderBase, Seq2SeqBase
from torch.utils.data import DataLoader, Dataset
import torch
import random
from preprocess import *
from base import *
from gru_summ_model import SummarizationModel

ModuleNotFoundError: No module named 'preprocess'

In [2]:
data = read_json()

In [3]:
captions_dict = get_caption_dict()

In [4]:
caption_text = gather_text(captions_dict)

In [6]:
vocab, vocab_to_int, int_to_vocab = create_vocab(caption_text)

In [7]:
train_video_objects = create_video_objects(0, 16, data)

In [8]:
transform = get_transforms()

In [9]:
train_data = VideoDataset(train_video_objects, vocab_to_int, transform)

In [10]:
train_loader = DataLoader(train_data, batch_size=4, shuffle=False)

In [11]:
valid_video_objects = create_video_objects(20, 36, data)

In [12]:
valid_data = VideoDataset(valid_video_objects, vocab_to_int, transform)

In [13]:
valid_loader = DataLoader(valid_data, batch_size=4, shuffle=False)

In [14]:
INPUT_DIM = 1024
OUTPUT_DIM = len(vocab) #len(vocab) #import vocab here
ENCODER_HID_DIM = 512
DECODER_HID_DIM = 512
EMBEDDING_DIM = 256
DROPOUT = 0.4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [15]:
encoder = EncoderBase(hidden_dim=ENCODER_HID_DIM, input_dim=INPUT_DIM, device=device)
decoder = DecoderBase(EMBEDDING_DIM, DECODER_HID_DIM, OUTPUT_DIM, DROPOUT)



model = Seq2SeqBase(encoder, decoder, device).to(device)

In [16]:
model

Seq2SeqBase(
  (encoder): EncoderBase(
    (gru): GRU(1024, 512, batch_first=True)
  )
  (decoder): DecoderBase(
    (embedding): Embedding(28790, 256)
    (gru): GRU(256, 512)
    (linear): Linear(in_features=512, out_features=28790, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
)

In [17]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

def get_opt_loss():
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    return optimizer, criterion

In [18]:
def count_parameters(model):
    parameters_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return parameters_count

In [19]:
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 25,684,598 trainable parameters


In [20]:
def train(model, loader, opt, criterion, clip):
    
    model.train()
    epoch_loss = 0.0
    
    for frames, captions in loader:
        
        # frames = [bs, seq_len, 3, 224, 224] = [bs, 32, 3, 224, 224]
        # captions = [bs, num_captions, seq_len] = [bs, 20, 20]
        
        #captions = captions.permute(1, 2, 0).type(torch.LongTensor)
         # captions = [num_captions, seq_len, bs] = [20, 20, bs]
        
        frames = frames.to(device)
        captions = captions.to(device)
       
        
        for i in range(captions.shape[1]):
            
            opt.zero_grad()
            
            output = model(frames, captions[:,i,:].long())
            print('output1: ', output.shape)
            # output [trg_len, bs, output_vocab_dim]
            # captions[i] = [trg_len, bs]
            
            output_dim = output.shape[-1]
            
            output = output[1:].view(-1, output_dim)
            # output = [(trg_len-1)*bs, output_dim]
            print('output2: ', output.shape)
            target = captions[:,i,:].T[1:].reshape(-1)
            # target = [(trg_len-1)*bs]
            print('target: ', target.shape)
            loss = criterion(output, target.long())
            
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
            opt.step()
        
            epoch_loss += loss.item()
        
    
    #print(epoch_loss/len(loader))
    return epoch_loss / len(loader)
            

In [21]:
def evaluate(model, loader, criterion):
    
    model.train()
    epoch_loss = 0.0
    
    for frames, captions in loader:
        
        # frames = [bs, seq_len, 3, 224, 224] = [bs, 32, 3, 224, 224]
        # captions = [bs, num_captions, seq_len] = [bs, 20, 20]
        
        #captions = captions.permute(1, 2, 0).type(torch.LongTensor)
         # captions = [num_captions, seq_len, bs] = [20, 20, bs]
        
        frames = frames.to(device)
        captions = captions.to(device)
       
        with torch.no_grad():
            for i in range(captions.shape[1]):

               

                output = model(frames, captions[:,i,:].long())

                # output [trg_len, bs, output_vocab_dim]
                # captions[i] = [trg_len, bs]

                output_dim = output.shape[-1]

                output = output[1:].view(-1, output_dim)
                # output = [(trg_len-1)*bs, output_dim]

                target = captions[:,i,:].T[1:].reshape(-1)
                # target = [(trg_len-1)*bs]

                loss = criterion(output, target.long())
                
                epoch_loss += loss.item()


    #print(epoch_loss/len(loader))
    return epoch_loss / len(loader)
            

In [22]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [23]:
import time, random
N_EPOCHS = 1
CLIP = 1
optimizer, criterion = get_opt_loss()
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_loader, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 28790])
output2:  torch.Size([76, 28790])
target:  torch.Size([76])
output1:  torch.Size([20, 4, 2

KeyboardInterrupt: 

In [24]:
x = torch.randn(20, 32, 100)

In [25]:
x[1:].shape

torch.Size([19, 32, 100])