In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.nn as nn

# NOTE(review): `_include_` is neither defined nor imported in this notebook;
# presumably it is injected by an IPython startup script and makes the
# 'curriculum_vqa' project importable — confirm, otherwise this cell fails on
# a fresh kernel.
_include_('curriculum_vqa')

# Seed both NumPy's and PyTorch's global RNGs with the same value.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed);

from cvqa import datasets, models, trainers, viz

In [3]:
from cvqa.curriculum import VQAInstanceDistribution

In [4]:
from fairseq.data import Dictionary
import torch.nn.functional as F

# Draw instances for 10 images with 5 prompts each and convert every instance
# into a {'prompt', 'target'} record.
# NOTE(review): the target text is appended to the prompt ("<prompt> ans = <target>"),
# so the answer is visible in the model input — presumably a deliberate
# copy-the-answer sanity check; confirm before reading accuracy as VQA skill.
samples = [
    {
        'prompt': ' ans = '.join((instance['prompt'], instance['target'])),
        'target': instance['target'],
    }
    for instance in VQAInstanceDistribution().sample_dataset(images=10, prompts_per_image=5)
]

# The dataset wrapper exposes the vocabulary used by the model builders below.
dataset = datasets.SimpleDataset(samples)
vocab = dataset.vocab

In [5]:
from cvqa import fairseq_misc
# NOTE(review): this `model` is rebound by `model = MyTransformer(encoder, decoder)`
# in cell 6 before it is ever used, so the object built here is never
# trained or evaluated in this notebook.
model = fairseq_misc.build_transformer(vocab)

In [6]:
from torch import nn

class MyTransformer(nn.Module):
    """Thin encoder-decoder wrapper.

    Holds an encoder and a decoder module and chains them: the source tokens
    are encoded, and the decoder is run on ``prev_output_tokens`` with the
    encoder result passed as ``encoder_out``.
    """

    def __init__(self, enc, dec):
        """Store the two sub-modules.

        Args:
            enc: module called as ``enc(src_tokens, None)``.
            dec: module called as ``dec(prev_output_tokens, encoder_out=...)``.
        """
        super().__init__()
        # Registered in this order (encoder first, then decoder).
        self.enc = enc
        self.dec = dec

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Encode ``src_tokens`` and decode ``prev_output_tokens`` against the result.

        ``src_lengths`` is accepted but not used: the encoder always receives
        ``None`` as its second positional argument.

        Returns:
            Whatever the decoder returns, unchanged.
        """
        return self.dec(
            prev_output_tokens,
            encoder_out=self.enc(src_tokens, None),
        )
    
# Build one embedding and hand the same object to both the encoder and the
# decoder builders.
# NOTE(review): 16 is presumably the embedding dimension — confirm against
# fairseq_misc.build_embedding.
tokens_embed = fairseq_misc.build_embedding(vocab, 16)
encoder = fairseq_misc.build_vqa_encoder(vocab, tokens_embed)
decoder = fairseq_misc.build_decoder(vocab, tokens_embed)
# Rebinds `model`, replacing the object created in cell 5.
model = MyTransformer(encoder, decoder)

In [7]:
from tqdm import tqdm 
import statistics

def model_forward(model, sample):
    """Run the model on one batch with a single decoder input step.

    Args:
        model: callable invoked as
            ``model(src_tokens, src_lengths, prev_output_tokens)``; element 0
            of its result is taken as the logits.
        sample: mapping with tensor entries ``'prompt'`` (model input, first
            dimension is the batch) and ``'target'``.

    Returns:
        tuple ``(logits, targets)`` where ``logits`` is reshaped to 2-D with
        its last dimension preserved and ``targets`` is flattened to 1-D.
    """
    src_tokens = sample['prompt']
    targets = sample['target']
    # Lengths are never computed here; the model always receives None.
    src_lengths = None

    batch_size = src_tokens.shape[0]
    # One decoder input token (index 0) per example.
    # NOTE(review): index 0 is presumably the vocabulary's BOS symbol — confirm.
    # Created on the same device as the prompt so the call also works when the
    # batch does not live on the default device.
    prev_output_tokens = torch.zeros(
        batch_size, 1, dtype=torch.int64, device=src_tokens.device
    )
    model_out = model(src_tokens, src_lengths, prev_output_tokens)
    logits = model_out[0]

    # reshape (rather than view) also accepts non-contiguous logits.
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.flatten()
    return logits, targets


def evaluate(model, dataset, ignore_index=None, iter_lim=None):
    dloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    model.eval()
    with torch.set_grad_enabled(False):
        correct = 0
        total = 0
        for i, sample in enumerate(dloader):
            if iter_lim is not None and i >= iter_lim:
                break

            logits, y_true = model_forward(model, sample)

            _, y_pred = torch.max(logits.data, -1)

            if ignore_index is not None:
                mask = y_true.ne(self.ignore_index)
                y_true = y_true[mask]
                y_pred = y_pred[mask]

            correct += (y_pred == y_true).sum()
            total += y_true.size(0)

        return float(correct) / float(total)
    
    
def train(model, dataset, optim, num_epochs=2, crit=None):
    """Train ``model`` on ``dataset`` for ``num_epochs`` epochs.

    Batches of 32 are drawn in shuffled order. The progress bar description
    shows the mean loss over the most recent (up to) 100 optimisation steps.

    Args:
        model: module passed to ``model_forward``; updated in place.
        dataset: dataset accepted by ``torch.utils.data.DataLoader``.
        optim: optimizer over the model's parameters.
        num_epochs: number of passes over the dataset.
        crit: loss called as ``crit(logits, targets)``. When None, a
            module-level ``crit`` is used if one exists (the way this function
            originally found its loss); otherwise ``nn.CrossEntropyLoss()``.

    Returns:
        None.
    """
    if crit is None:
        # Backward compatibility: callers used to define the loss as a
        # notebook-level global named `crit` instead of passing it in.
        crit = globals().get('crit')
    if crit is None:
        crit = nn.CrossEntropyLoss()

    dloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    train_loss = []
    with tqdm(range(num_epochs)) as prg_train:
        for epoch in prg_train:
            for sample in dloader:
                # Re-assert training mode on every step.
                model.train()
                optim.zero_grad()

                logits, targets = model_forward(model, sample)
                loss = crit(logits, targets)

                loss.backward()
                optim.step()

                train_loss.append(loss.item())
                # Slicing past the start of the list just yields the whole list.
                running_mean_loss = statistics.mean(train_loss[-100:])
                status_str = f'[epoch={epoch}] loss: {running_mean_loss:.3f}'
                prg_train.set_description(status_str)


In [8]:
# Adam over every parameter of the wrapper model.
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# `train` picks this loss up through the module-level name `crit`, so it has
# to be defined before the call below.
crit = nn.CrossEntropyLoss()
train(model, dataset, optim, num_epochs=250)

[epoch=249] loss: 0.639: 100%|██████████| 250/250 [00:07<00:00, 35.17it/s]


In [9]:
# Scored on the same `dataset` object that was passed to `train` above, so
# this is accuracy on the training data, not on a held-out split.
evaluate(model, dataset)

0.96