TODO
- Handle `.eval()` / `.train()` mode switching correctly (the models may use dropout).

- Conv1D: https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py
- GPT2 hf code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

# GPT-2 direct

In [1]:
import transformers
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from typing import Optional, Tuple, Union
import torch

import transformers.models.gpt2.modeling_gpt2
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP

In [2]:
# Only the configuration is needed to switch on cross-attention, so load it
# directly instead of instantiating (and loading weights for) a full GPT2Model
# that would be discarded immediately.
from transformers import AutoConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt_config = AutoConfig.from_pretrained("gpt2")
# Adds a cross-attention layer to every block; its weights are not in the
# checkpoint and get freshly initialized (see the warning in the cell output).
gpt_config.add_cross_attention = True
gpt_config.is_decoder = True
model = AutoModelForCausalLM.from_pretrained("gpt2", config=gpt_config).cuda()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.4.crossattention.q_attn.weight', 'h.4.crossattention.bias', 'h.2.crossattention.bias', 'h.6.crossattention.c_attn.weight', 'h.4.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.3.crossattention.c_proj.weight', 'h.5.crossattention.c_attn.weight', 'h.1.ln_cross_attn.weight', 'h.8.crossattention.c_attn.weight', 'h.8.crossattention.c_proj.weight', 'h.0.crossattention.c_proj.weight', 'h.10.crossattention.bias', 'h.10.crossattention.c_proj.weight', 'h.5.ln_cross_attn.weight', 'h.7.crossattention.bias', 'h.10.crossattention.c_attn.weight', 'h.11.crossattention.masked_bias', 'h.6.ln_cross_attn.weight', 'h.11.crossattention.c_proj.bias', 'h.7.ln_cross_attn.weight', 'h.9.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.weight', 'h.8.crossattention.bias', 'h.4.crossattention.masked_bias', 'h.7.crossattention.masked_bias', 'h.6.crossattention.bias', 'h.5

In [4]:
# Parameter count of the cross-attention layer in the first GPT-2 block.
sum(map(torch.Tensor.numel, model.transformer.h[0].crossattention.parameters()))

2362368

In [134]:
import copy

import torch.nn as nn  # `nn` is used below but was never imported in this section


class GPT2BlockPatched(nn.Module):
    """Replacement for transformers' ``GPT2Block`` (installed at the end of this cell).

    Layout: pre-LayerNorm self-attention, then an optional pre-LayerNorm
    cross-attention (only built when ``config.add_cross_attention`` is set),
    then a pre-LayerNorm MLP, each wrapped in a residual connection. The
    cross-attention layer is built from a deep copy of ``config`` so its
    hyper-parameters can be experimented with independently of the rest of
    the model.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()

        print('Initializing GPT2BlockPatched with config:', config)

        hidden_size = config.hidden_size
        # MLP width: config.n_inner when given, otherwise 4 * hidden_size.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            # Private copy so cross-attention settings can be changed (e.g.
            # cross_attention_config.hidden_size = 24) without mutating the
            # config shared by the rest of the model.
            cross_attention_config = copy.deepcopy(config)
            self.crossattention = GPT2Attention(cross_attention_config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Apply self-attention, optional cross-attention and the MLP.

        Returns:
            ``(hidden_states, present, attentions..., cross_attentions...)`` when
            ``use_cache`` is truthy; otherwise the same tuple with the
            ``present`` entry dropped. The attention entries are whatever the
            attention layers return beyond their first two outputs (per the
            upstream convention, only when ``output_attentions`` is set).

        Raises:
            ValueError: if ``encoder_hidden_states`` is given but this block was
                built without cross-attention.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop `present` (outputs[0]) when no cache is requested.
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)


# Rebinds the module-level name only. Presumably GPT-2 models constructed after
# this line pick up the patched block; models that already exist (e.g. `model`
# from the cell above) keep the blocks they were built with.
transformers.models.gpt2.modeling_gpt2.GPT2Block = GPT2BlockPatched

In [138]:
# Total number of parameters added by the cross-attention layers across all
# GPT-2 blocks (the original nested loops reused `x` for both loop variables).
n_params = sum(
    p.numel()
    for block in model.transformer.h
    for p in block.crossattention.parameters()
)
n_params

28348416

In [11]:
# Calling the tokenizer directly replaces the deprecated `batch_encode_plus`;
# it returns the same BatchEncoding (input_ids + attention_mask).
prompt = tokenizer(['www.'], return_tensors='pt')

In [12]:
# encoder_hidden_states: (encoder_batch_size, encoder_sequence_length, hidden_size).
# A single all-zeros "encoder token" is enough to exercise the cross-attention path.
# Call the module (not `.forward`) so nn.Module hooks run, and skip autograd
# bookkeeping since this cell only inspects the output.
with torch.no_grad():
    result = model(
        input_ids=prompt['input_ids'].to(model.device),
        encoder_hidden_states=torch.zeros(1, 1, model.config.hidden_size, device=model.device),
    )
result.logits.sum()

tensor(-4802594.5000, device='cuda:0', grad_fn=<SumBackward0>)

In [16]:
# NOTE: the execution counts show this cell (In [16]) ran after the `generate`
# cell below (In [15]). The displayed value is therefore the generated token
# ids, not the forward-pass output from the cell above.
result

tensor([[ 2503,    13, 11604,    13,   785,    14,  8340,    30,    85,    28,
            55,    80,    55,    80,    55,    80,    55,    80,    55,    48,
            19,   198,   198,  4023,  1378,  2503,    13, 11604,    13,   785,
            14,  8340,    30,    85,    28,    55,    80,    55,    80,    55,
            80,    55,    80,    55,    48,    19,   198,   198,  4023,  1378,
          2503,    13, 11604,    13,   785,    14,  8340,    30,    85,    28,
            55,    80,    55,    80]], device='cuda:0')

In [15]:
# pad_token_id is passed explicitly. It reproduces generate's fallback
# (eos_token_id) without printing the "Setting `pad_token_id`" warning each call.
# The encoder states are a constant 100 everywhere, with one encoder token.
result = model.generate(
    inputs=prompt['input_ids'].to(model.device),
    attention_mask=prompt['attention_mask'].to(model.device),
    max_length=64,
    pad_token_id=tokenizer.eos_token_id,
    encoder_hidden_states=torch.full((1, 1, model.config.hidden_size), 100.0, device=model.device),
)
tokenizer.batch_decode(result)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['www.youtube.com/watch?v=XqXqXqXqXQ4\n\nhttp://www.youtube.com/watch?v=XqXqXqXqXQ4\n\nhttp://www.youtube.com/watch?v=XqXq']

# Load MusicCaps

In [1]:
from musiccaps import load_musiccaps
import numpy as np
from rich import print as printr

In [2]:
# MusicCaps rows (return_without_audio=True presumably skips loading the audio).
ds = load_musiccaps(
    './music_data',
    return_without_audio=True,
    sampling_rate=16000,
    num_proc=8,
    writer_batch_size=1000,
    limit=None,
)
# The .npy holds a pickled Python object (hence allow_pickle + .item()).
# create_batcher indexes it by ytid. Only load files you produced yourself:
# unpickling can execute arbitrary code.
embeddings = np.load(file='embeddings.npy', allow_pickle=True).item()

Using custom data configuration google--MusicCaps-7925612b943f961b
Found cached dataset csv (/home/dominik/.cache/huggingface/datasets/google___csv/google--MusicCaps-7925612b943f961b/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


# Image captioning

In [3]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import transformers
import torch
from PIL import Image
import matplotlib.pyplot as plt
import torchinfo
import torch.nn as nn
from tqdm.auto import tqdm
import itertools

In [26]:
# One checkpoint id for all three components, so they cannot drift apart.
checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(checkpoint).cuda()
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Keep a reference to the original ViT encoder forward before
# `model.encoder.forward` is monkey-patched further down.
encoder_forward = model.encoder.forward



In [51]:
def create_batcher(bs, device='cuda'):
    """Yield training batches forever, reshuffling the dataset every epoch.

    Reads the notebook globals ``ds`` (dataset rows with 'caption' and 'ytid'),
    ``embeddings`` (mapping ytid -> embedding array) and ``tokenizer``.

    Args:
        bs: examples per batch. Examples left over at the end of an epoch
            (fewer than ``bs``) are dropped.
        device: device the yielded tensors are moved to (default 'cuda').

    Yields:
        ``(captions, captions_tok, embs, epoch)``:
        - ``captions``: the raw captions, each with ``' <|endoftext|>'`` appended.
        - ``captions_tok``: a (bs, max_len) int64 tensor of token ids,
          right-padded with -100, which PyTorch's cross-entropy ignores by default.
        - ``embs``: the stacked embeddings as a tensor.
        - ``epoch``: the 0-based epoch counter.
    """
    for epoch in itertools.count():
        captions, embs = [], []

        for i in np.random.permutation(len(ds)):
            row = ds[int(i)]  # index with a Python int, not a NumPy integer
            try:
                cap, emb = row['caption'], embeddings[row['ytid']]
            except KeyError:
                # Clip without a precomputed embedding: skip it. Other errors
                # are no longer silently swallowed.
                continue

            captions.append(cap + ' <|endoftext|>')
            embs.append(emb)

            if len(captions) == bs:
                token_ids = [torch.tensor(ids) for ids in tokenizer(captions)['input_ids']]
                captions_tok = torch.nn.utils.rnn.pad_sequence(
                    token_ids, batch_first=True, padding_value=-100
                ).to(device)
                batch_embs = torch.from_numpy(np.stack(embs)).to(device)
                yield captions, captions_tok, batch_embs, epoch
                captions, embs = [], []
        print(f'Finished {epoch+1} epochs')

In [64]:
class B2T(nn.Module):
    """Three-layer MLP projecting an input embedding to decoder-sized vectors.

    Linear -> ReLU -> Linear -> ReLU -> Linear. The defaults (512 -> 768 -> 768)
    reproduce the original hard-coded sizes, so existing code and saved
    state_dicts keep working. The output feeds the encoder hidden states that
    ``patched_forward`` hands to the decoder.
    """

    def __init__(self, in_dim=512, hidden_dim=768, out_dim=768):
        super().__init__()
        # Attribute name and layer order are unchanged, so state_dict keys stay
        # main.0.*, main.2.*, main.4.*.
        self.main = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size ``in_dim``)."""
        return self.main(x)
    
b2t = B2T().cuda()

# Only the projection MLP and the decoder's cross-attention layers are trained.
cross_attention_params = [
    p for block in model.decoder.transformer.h for p in block.crossattention.parameters()
]

# Freeze every other model weight. Otherwise backward() computes gradients for
# weights the optimizer never updates, and opt.zero_grad() (which only clears
# the optimizer's own params) never resets them, so they keep accumulating.
# Gradients still flow *through* the frozen layers to the trained ones.
model.requires_grad_(False)
for p in cross_attention_params:
    p.requires_grad_(True)

opt = torch.optim.AdamW([
    {'params': b2t.parameters(), 'lr': 1e-4},
    {'params': cross_attention_params, 'lr': 1e-6, 'weight_decay': 0},
])

losses = []
bs = 32
batcher = create_batcher(bs)

In [65]:
def patched_forward(*args, **kwargs):
    """Stand-in for the ViT encoder's forward.

    Every argument (pixel values included) is ignored. Instead, the
    notebook-global ``embs`` is passed through ``b2t``, and a singleton axis is
    inserted at dim 1 so the result looks like a one-step encoder sequence.
    This assumes ``embs`` is (batch, 512); TODO confirm.
    """
    projected = b2t(embs)
    return transformers.modeling_outputs.BaseModelOutputWithPooling(
        last_hidden_state=projected.unsqueeze(1)
    )


model.encoder.forward = patched_forward

In [84]:
# Call the tokenizer directly (batch_encode_plus is deprecated) and put the ids
# on the same device as the model.
prompt = tokenizer(['Attributes that describe this song:'], return_tensors='pt')['input_ids'].to(model.device)

In [91]:
# Call the module rather than `.forward` so nn.Module hooks run. No
# encoder_hidden_states are passed. The GPT-2 block forward only runs its
# cross-attention branch when they are given, so this is presumably the plain
# language model applied to `prompt`.
model.decoder(prompt)

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-36.2390, -26.7781, -27.6222,  ..., -48.4080, -43.9808, -34.2555],
         [-56.4835, -54.6282, -58.1651,  ..., -64.9682, -54.9985, -54.3861],
         [-42.1242, -40.6423, -45.6622,  ..., -49.1798, -44.5205, -40.5976],
         [-48.5948, -46.8968, -53.2402,  ..., -56.6921, -49.9241, -46.9453],
         [-42.9220, -41.5556, -46.1100,  ..., -52.7577, -47.9151, -41.5438],
         [-28.3301, -24.8079, -27.8425,  ..., -35.5392, -30.3479, -21.8031]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-1.0436,  1.8444,  0.5351,  ..., -0.0334, -0.3260,  0.9289],
          [-2.1881,  1.6063,  1.0011,  ..., -0.8330, -1.2624,  1.8925],
          [-1.4904,  1.3848,  0.8831,  ..., -0.9845, -1.4762,  1.1756],
          [-1.7415,  1.7767,  1.7490,  ..., -0.9805, -1.6425,  1.6686],
          [-2.1586,  1.9328,  0.2944,  ..., -1.2616, -2.2407,  0.5414],
          [-1.9061,  1.5793,  1.9901,  ..., -1.194

In [89]:
# Continue `prompt` with the decoder alone (no encoder states are supplied).
result = model.decoder.generate(inputs=prompt)



In [90]:
# Token ids -> text (special tokens such as <|endoftext|> are kept).
tokenizer.batch_decode(sequences=result)

['Attributes that describe this song: "the sun is shining on the man in the hat <|endoftext|>']

In [67]:
from rich.markup import escape

grad_accum = 2
print('Effective batch size:', grad_accum*bs)
for step in tqdm(itertools.count()):
    # `embs` is a global: the patched encoder forward reads it.
    captions, captions_tok, embs, epoch = next(batcher)

    # The positional pixel_values argument is a dummy; the patched encoder
    # forward ignores it and uses `embs` instead.
    loss = model(1, labels=captions_tok).loss

    # Average over the accumulated micro-batches so one optimizer step matches
    # a single batch of grad_accum*bs examples (summing would scale the
    # gradient by grad_accum).
    (loss / grad_accum).backward()
    # Step once every grad_accum micro-batches. The old `step % grad_accum == 0`
    # also stepped on the very first micro-batch alone.
    if (step + 1) % grad_accum == 0:
        opt.step()
        opt.zero_grad()

    losses.append(loss.item())

    if step % 20 == 0:
        # Rebind the global so generation is conditioned on the first example only.
        embs = embs[0:1]
        # Captions and predictions may contain '[', which rich would parse as
        # markup, so escape them before printing.
        output_ids = model.generate(None, max_length=256, num_beams=2)
        pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        printr('[blue bold] PREDICTION1: ' + escape(pred))
        output_ids = model.generate(None, max_length=256, num_beams=4, do_sample=True, temperature=0.8)
        pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        printr('[blue bold] PREDICTION2: ' + escape(pred))
        printr('[green bold] TRUE CAPTION: ' + escape(captions[0]))
        print()

    if step % 200 == 199:
        plt.plot(losses)
        plt.show()

Effective batch size: 64


0it [00:00, ?it/s]










KeyboardInterrupt: 