In [1]:
import torch
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader, Subset
from utils.data import COCOAEDataset, collate_fn
from utils.transforms import get_transform
from utils.transforms import ResizeTransform
from noise.scheduler import NoiseScheduler, LinearMaskScheduler, mask_image
from models.masked_autoencoder import MaskedAEConfig, MaskedAutoEncoderForPretraining, MaskedAutoEncoderForCaptioning, MaskedAutoEncoder
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
original_dataset = COCOAEDataset(root="coco/images/train2017/",
                        annFile="coco/annotations/ann2017/captions_train2017.json",
                        transform=get_transform(),
                        tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir='cache/'),
                        ignore_cache=False,
                        train=True)


Loading cached annotations...


In [3]:
train_dataset = Subset(original_dataset, range(5))
val_dataset = Subset(original_dataset, range(100, 110))



In [40]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
train_dataloader = DataLoader(train_dataset,
                        batch_size=4,
                        shuffle=True,
                        collate_fn=collate_fn(train_dataset.dataset.tokenizer.pad_token_id),
                        pin_memory=True)
val_dataloader = DataLoader(val_dataset,
                        batch_size=4,
                        shuffle=True,
                        collate_fn=collate_fn(val_dataset.dataset.tokenizer.pad_token_id),
                        pin_memory=True)

noise_scheduler = LinearMaskScheduler(vocab_size=len(train_dataset.dataset.tokenizer), masking_ratio=0.0)



In [5]:
config = MaskedAEConfig(len(train_dataset.dataset.tokenizer))
pretrained = MaskedAutoEncoder(config).to(DEVICE)
checkpoint = torch.load("checkpoints/base_0",map_location=torch.device('cpu'))
pretrained.load_state_dict(checkpoint)

<All keys matched successfully>

In [27]:
model = MaskedAutoEncoderForCaptioning(MaskedAEConfig(len(train_dataset.dataset.tokenizer)), pretrained=pretrained).to(DEVICE)

In [7]:
optim = torch.optim.Adam(model.parameters(), lr=4e-5)
optim = torch.optim.AdamW(model.parameters(), lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.03)

image_loss = torch.nn.MSELoss()
caption_loss = torch.nn.CrossEntropyLoss()
# caption_loss = torch.nn.MSELoss()

In [52]:
for epoch in (pbar := tqdm(range(10))):
    for images, captions, lengths in train_dataloader:
        optim.zero_grad()
        images = images.to(DEVICE, non_blocking=True)
        captions = captions.to(DEVICE, non_blocking=True)
        lengths = lengths.to(DEVICE, non_blocking=True)
        # print(images.dtype, captions.dtype, lengths.dtype)

        masked_images, text, targets, (image_positions, text_pad) = noise_scheduler.get_masked(images, captions, lengths, need_masks=True)
        # print(masked_images.shape, masked_text.shape)   
        print(captions.shape, text.shape)     
        reconstructed_captions = model.forward(masked_images, captions, text_pad, image_positions)
        print(reconstructed_captions.shape)
        if epoch % 5 == 0:
            # for c in text:
            #     print("Original:", train_dataset.dataset.tokenizer.decode(c))
            for c in captions:
                print("Original:", train_dataset.dataset.tokenizer.decode(c))    
            for c in reconstructed_captions:
                print("Generated:", train_dataset.dataset.tokenizer.decode(torch.argmax(c, dim=-1)))    
            # for caption in reconstructed_captions:
            #     # print(caption.shape)
            #     values, indices = torch.topk(caption, 10)
            #     # print(values, indices)
            #     for i in indices:
            #         print(train_dataset.dataset.tokenizer.decode(i), end = ", ")      
            #     print("")                               
        # print(reconstructed_captions.shape, targets.shape)
        shifted_original = captions[:,:-1]
        shifted_reconstructed = reconstructed_captions[:,1:]
        # print(shifted_original.shape, shifted_reconstructed.shape)
        cap_loss = caption_loss(shifted_reconstructed.permute(0,2,1), shifted_original)
        # print(cap_loss)
        # cap_loss = caption_loss(reconstructed_captions, targets)
        # cap_loss = caption_loss(reconstructed_captions.permute(0, 2, 1), padded_caption)
        pbar.set_description(f"Epoch: {epoch}, Caption Loss : {cap_loss.item():1.3}")
        cap_loss.backward()
        optim.step()
        
    


  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([4, 18]) torch.Size([4, 18])


Epoch: 0, Caption Loss : 10.2:   0%|          | 0/10 [00:04<?, ?it/s]

torch.Size([4, 18, 30522])
Original: [CLS] a man with a red helmet on a small moped on a dirt road. [SEP]
Original: [CLS] a young boy barefoot holding an umbrella touching the horn of a cow [SEP] [PAD] [PAD] [PAD]
Original: [CLS] he is listening intently to the computer at school. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Original: [CLS] a woman wearing a net on her head cutting a cake. [SEP] [PAD] [PAD] [PAD] [PAD]
Generated: on on on on on on on on on on on on on on on on on on
Generated: umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella umbrella
Generated: [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Generated: cake cake cake cake cake cake woman woman cake cake cake cake cake cake cake cake cake cake
torch.Size([1, 16]) torch.Size([1, 16])


Epoch: 0, Caption Loss : 10.3:   0%|          | 0/10 [00:12<?, ?it/s]

torch.Size([1, 16, 30522])
Original: [CLS] a boy wearing headphones using one computer in a long row of computers [SEP]
Generated: [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


Epoch: 0, Caption Loss : 10.3:  10%|█         | 1/10 [00:14<02:10, 14.54s/it]

torch.Size([4, 21]) torch.Size([4, 21])


Epoch: 1, Caption Loss : 10.2:  10%|█         | 1/10 [00:17<02:10, 14.54s/it]

torch.Size([4, 21, 30522])
torch.Size([1, 12]) torch.Size([1, 12])


Epoch: 1, Caption Loss : 10.3:  10%|█         | 1/10 [00:22<02:10, 14.54s/it]

torch.Size([1, 12, 30522])


Epoch: 1, Caption Loss : 10.3:  20%|██        | 2/10 [00:24<01:35, 11.95s/it]

torch.Size([4, 18]) torch.Size([4, 18])


Epoch: 2, Caption Loss : 10.2:  20%|██        | 2/10 [00:27<01:35, 11.95s/it]

torch.Size([4, 18, 30522])
torch.Size([1, 17]) torch.Size([1, 17])


Epoch: 2, Caption Loss : 10.3:  20%|██        | 2/10 [00:32<01:35, 11.95s/it]

torch.Size([1, 17, 30522])


Epoch: 2, Caption Loss : 10.3:  30%|███       | 3/10 [00:34<01:16, 10.98s/it]

torch.Size([4, 16]) torch.Size([4, 16])


Epoch: 3, Caption Loss : 10.3:  30%|███       | 3/10 [00:36<01:16, 10.98s/it]

torch.Size([4, 16, 30522])
torch.Size([1, 17]) torch.Size([1, 17])


Epoch: 3, Caption Loss : 10.3:  30%|███       | 3/10 [00:42<01:16, 10.98s/it]

torch.Size([1, 17, 30522])


Epoch: 3, Caption Loss : 10.3:  40%|████      | 4/10 [00:44<01:03, 10.61s/it]

torch.Size([4, 17]) torch.Size([4, 17])


Epoch: 4, Caption Loss : 10.2:  40%|████      | 4/10 [00:47<01:03, 10.61s/it]

torch.Size([4, 17, 30522])
torch.Size([1, 37]) torch.Size([1, 37])


Epoch: 4, Caption Loss : 10.3:  40%|████      | 4/10 [00:52<01:03, 10.61s/it]

torch.Size([1, 37, 30522])


Epoch: 4, Caption Loss : 10.3:  50%|█████     | 5/10 [00:54<00:52, 10.42s/it]

torch.Size([4, 16]) torch.Size([4, 16])


Epoch: 5, Caption Loss : 10.2:  50%|█████     | 5/10 [00:57<00:52, 10.42s/it]

torch.Size([4, 16, 30522])
Original: [CLS] a child holding a flowered umbrella and petting a yak. [SEP]
Original: [CLS] children sitting at computer stations on a long table. [SEP] [PAD] [PAD] [PAD] [PAD]
Original: [CLS] a man riding on the back of a motorcycle. [SEP] [PAD] [PAD] [PAD] [PAD]
Original: [CLS] a woman wearing a net on her head cutting a cake. [SEP] [PAD] [PAD]
Generated: cow cow cow cow cow cow cow cow cow cow cow cow cow cow cow cow
Generated: [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] blast blast [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Generated: on a a on a on on a a a a on on on on on
Generated: cake cake woman cake woman woman cake woman cake woman cake woman woman woman cake cake
torch.Size([1, 14]) torch.Size([1, 14])


Epoch: 5, Caption Loss : 10.3:  50%|█████     | 5/10 [01:02<00:52, 10.42s/it]

torch.Size([1, 14, 30522])
Original: [CLS] a little boy wearing headphones and looking at a computer monitor [SEP]
Generated: [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


Epoch: 5, Caption Loss : 10.3:  60%|██████    | 6/10 [01:04<00:40, 10.21s/it]

torch.Size([4, 37]) torch.Size([4, 37])


Epoch: 6, Caption Loss : 10.2:  60%|██████    | 6/10 [01:07<00:40, 10.21s/it]

torch.Size([4, 37, 30522])
torch.Size([1, 15]) torch.Size([1, 15])


Epoch: 6, Caption Loss : 10.3:  60%|██████    | 6/10 [01:13<00:40, 10.21s/it]

torch.Size([1, 15, 30522])


Epoch: 6, Caption Loss : 10.3:  70%|███████   | 7/10 [01:15<00:31, 10.35s/it]

torch.Size([4, 18]) torch.Size([4, 18])


Epoch: 7, Caption Loss : 10.2:  70%|███████   | 7/10 [01:17<00:31, 10.35s/it]

torch.Size([4, 18, 30522])
torch.Size([1, 13]) torch.Size([1, 13])


Epoch: 7, Caption Loss : 10.3:  70%|███████   | 7/10 [01:24<00:31, 10.35s/it]

torch.Size([1, 13, 30522])


Epoch: 7, Caption Loss : 10.3:  80%|████████  | 8/10 [01:25<00:21, 10.51s/it]

torch.Size([4, 37]) torch.Size([4, 37])


Epoch: 8, Caption Loss : 10.2:  80%|████████  | 8/10 [01:28<00:21, 10.51s/it]

torch.Size([4, 37, 30522])
torch.Size([1, 13]) torch.Size([1, 13])


Epoch: 8, Caption Loss : 10.3:  80%|████████  | 8/10 [01:34<00:21, 10.51s/it]

torch.Size([1, 13, 30522])


Epoch: 8, Caption Loss : 10.3:  90%|█████████ | 9/10 [01:36<00:10, 10.43s/it]

torch.Size([4, 21]) torch.Size([4, 21])


Epoch: 9, Caption Loss : 10.2:  90%|█████████ | 9/10 [01:38<00:10, 10.43s/it]

torch.Size([4, 21, 30522])
torch.Size([1, 15]) torch.Size([1, 15])


Epoch: 9, Caption Loss : 10.3:  90%|█████████ | 9/10 [01:44<00:10, 10.43s/it]

torch.Size([1, 15, 30522])


Epoch: 9, Caption Loss : 10.3: 100%|██████████| 10/10 [01:46<00:00, 10.61s/it]


In [56]:
for text in captions:
    print(train_dataset.dataset.tokenizer.decode(text))
for caption in reconstructed_captions:
    print(train_dataset.dataset.tokenizer.decode(torch.argmax(caption, dim=-1)))
# for caption in reconstructed_captions:
#     # print(caption.shape)
#     values, indices = torch.topk(caption, 10)
#     # print(values, indices)
#     for i in indices:
#         print(train_dataset.dataset.tokenizer.decode(i), end=", ")


[CLS] a young boy barefoot holding an umbrella touching the horn of a cow [SEP]
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


: 

In [None]:
model.eval() 
total_loss = 0.0

with torch.no_grad():
    for images, captions, lengths in val_dataloader:
        images = images.to(DEVICE, non_blocking=True)
        captions = captions.to(DEVICE, non_blocking=True)
        masked_images, masked_text, targets, (image_positions, rp) = noise_scheduler.get_masked(images, captions, lengths, need_masks=True)
        print(targets)
        for image in images:
            plt.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
            plt.show() 
        reconstructed = model.forward(masked_images, lengths)
        for caption in reconstructed:
            values, indices = torch.topk(caption, 10)
            for i in indices:
                print(val_dataset.dataset.tokenizer.decode(i))
        cap_loss = caption_loss(reconstructed, targets)
        total_loss += cap_loss.item() * images.size(0)

avg_loss = total_loss / len(val_dataset)
print(f"Validation Loss: {avg_loss:.3f}")


In [None]:
test = DataLoader(train_dataset,
                        batch_size=1,
                        shuffle=True,
                        collate_fn=collate_fn(train_dataset.dataset.tokenizer.pad_token_id),
                        pin_memory=True)
for images, captions, lengths in test:
    optim.zero_grad()
    images = images.to(DEVICE, non_blocking=True)
    captions = captions.to(DEVICE, non_blocking=True)
    lengths = lengths.to(DEVICE, non_blocking=True)
    # print(images.dtype, captions.dtype, lengths.dtype)

    masked_images, text, targets, (image_positions, text_pad) = noise_scheduler.get_masked(images, captions, lengths, need_masks=True)
    print(text_pad)
    print()