In [3]:
import torch

from transformers import GPT2Model, GPT2Config, GPT2Tokenizer

class GPT2WithIntermediateOutputs(GPT2Model):
    """GPT-2 model whose ``forward`` returns the hidden state after every stage.

    The returned tensor has shape ``(n_layer + 1, batch_size * seq_len, n_embd)``:
    index 0 along the first axis is the (dropout-applied) token + position
    embedding, and index ``i`` (1..n_layer) is the output of transformer block
    ``i - 1``. The final ``ln_f`` LayerNorm is never applied, so the last entry
    is the raw output of the last block.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, attention_mask=None):
        """Run the embedding layer and every transformer block, keeping each output.

        Args:
            input_ids: integer token-id tensor whose last dimension is the
                sequence axis; all leading dimensions are flattened into one
                batch axis.
            attention_mask: optional 0/1 mask (1 = attend, 0 = padding), either
                2-D ``(batch, seq)`` or 3-D ``(batch, seq, seq)``. ``None`` means
                every token is attended.

        Returns:
            Tensor of shape ``(n_layer + 1, batch * seq_len, n_embd)``.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        input_shape = input_ids.size()
        seq_len = input_shape[-1]
        input_ids = input_ids.view(-1, seq_len)
        device = input_ids.device

        # GPT-2 blocks add `attention_mask` to the raw attention scores, so a 0/1
        # mask must be turned into an additive one (0 = keep, very negative =
        # masked), the same conversion GPT2Model.forward performs. Passing the
        # 0/1 mask directly would only shift every score by +1 and never mask.
        # With no mask, pass None so the blocks use only their causal masking.
        # NOTE(review): SDPA attention backends in newer transformers versions may
        # skip causal masking whenever a mask is supplied; verify for your version.
        extended_attention_mask = None
        if attention_mask is not None:
            if attention_mask.dim() == 3:
                extended_attention_mask = attention_mask[:, None, :, :]
            elif attention_mask.dim() == 2:
                extended_attention_mask = attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    f"attention_mask must be 2-D or 3-D, got {attention_mask.dim()}-D"
                )
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        # No head mask is requested; this yields one `None` entry per layer.
        head_mask = self.get_head_mask(None, self.config.n_layer)

        position_ids = torch.arange(0, seq_len, device=device)
        hidden_states = self.drop(self.wte(input_ids) + self.wpe(position_ids))

        layer_outputs = [hidden_states]
        for i, block in enumerate(self.h):
            outputs = block(
                hidden_states,
                attention_mask=extended_attention_mask,
                head_mask=head_mask[i],
            )
            hidden_states = outputs[0]
            layer_outputs.append(hidden_states)

        # Stack on a new leading layer axis -> (n_layer + 1, batch, seq, hidden),
        # then merge batch and sequence. Stacking first keeps each layer's values
        # contiguous, so the reshape cannot mix values across layers or features.
        stacked = torch.stack(layer_outputs, dim=0)
        return stacked.reshape(len(layer_outputs), -1, hidden_states.size(-1))
 

# Example usage:

config = GPT2Config.from_pretrained("gpt2")

# Constructing the class directly from a config gives randomly initialised
# weights; from_pretrained loads the pretrained GPT-2 weights instead.
# eval() disables dropout so the extracted features are deterministic.
model = GPT2WithIntermediateOutputs.from_pretrained("gpt2", config=config)
model.eval()

print(model)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = torch.tensor([tokenizer.encode("if the vocabulary is known, then the sequence length is correct")])

# Output has dimensions [num_layers + 1, batch_size * seq_length, features];
# element 0 of the first axis is the embedding layer, not a transformer block.
with torch.no_grad():
    outputs = model(input_ids)


print(outputs.shape)  # Output dimensions

GPT2WithIntermediateOutputs(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
torch.Size([13, 12, 768])


In [5]:
import pickle

# NOTE: output_hidden_states is ignored by GPT2WithIntermediateOutputs.forward,
# which always returns every layer's hidden state.
config = GPT2Config.from_pretrained('gpt2', output_hidden_states=True)
# Load the pretrained weights (a config-only construction would be random) and
# disable dropout so the cached features are deterministic.
model = GPT2WithIntermediateOutputs.from_pretrained('gpt2', config=config)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Example data: (text, label) pairs
data = [
    ("Hello, world!", 1),
    ("This is a test.", 0),
    ("This isn't a test.", 1),
    ("Another much longer test", 1)
]

# Process each sentence into a (n_layer + 1, seq_len, n_embd) feature tensor;
# sentences have different token counts, so they are kept as separate tensors.
processed_data = []
for text, label in data:
    input_ids = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)
    print(outputs.shape)

    processed_data.append((outputs, label))

# Save the processed data
with open('processed_data.pkl', 'wb') as f:
    pickle.dump(processed_data, f)

torch.Size([13, 4, 768])
torch.Size([13, 5, 768])
torch.Size([13, 6, 768])
torch.Size([13, 4, 768])


In [7]:
import torch.nn as nn
from gaussian_adaptive_attention import MultiHeadGaussianAdaptiveAttention

class ClassifierWithGAAM(nn.Module):
    """Gaussian adaptive attention followed by a small CNN and a linear head.

    Args:
        num_classes: number of output logits.
        num_gaussians: forwarded to ``MultiHeadGaussianAdaptiveAttention``.
        norm_axis: axis the GAAM layer normalises over; the original notes
            suggest -2 (the flattened layer*sequence rows) — TODO confirm the
            meaning of 0 now that a channel axis is inserted at dim 1.
        num_heads: number of GAAM heads (default 5, the previous hard-coded value).
        padding_value: GAAM padding value (default 0, the previous hard-coded value).
    """

    def __init__(self, num_classes, num_gaussians, norm_axis, num_heads=5, padding_value=0):
        super().__init__()
        self.gaam = MultiHeadGaussianAdaptiveAttention(
            norm_axis=norm_axis,
            num_heads=num_heads,
            num_gaussians=num_gaussians,
            padding_value=padding_value
        )
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        # Adaptive pooling to (1, 1) always leaves 64 features (conv2's channels),
        # independent of the input's height/width.
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is assumed to be ``(batch, rows, features)`` — e.g. ``(1, L*seq, 768)``
        as built by the training cell. TODO confirm that the GAAM layer
        preserves the ``(batch, 1, rows, features)`` shape.
        """
        # Insert the single image channel after the batch axis. unsqueeze(0)
        # only worked for batch size 1; for batch 1 both give the same tensor.
        x = self.gaam(x.unsqueeze(1))
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.fc(x)

# Build the binary classifier over the flattened GPT-2 layer outputs,
# normalising along the row axis (-2).
classifier = ClassifierWithGAAM(
    norm_axis=-2,
    num_gaussians=3,
    num_classes=2,
)


In [12]:
import torch.optim as optim

# Load the cached GPT-2 features. pickle.load can execute arbitrary code, so
# only load a file this notebook produced itself.
with open('processed_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

# Setup optimizer and loss function
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 10  # toy dataset of 4 examples; tune as needed

# Train one example per step: each sentence has a different token count, so the
# feature tensors cannot be stacked into one batch without padding.
classifier.train()
for epoch in range(NUM_EPOCHS):
    for x, label in processed_data:
        flattened = x.view(1, -1, x.size(-1))
        target = torch.tensor([label])
        optimizer.zero_grad()
        loss = criterion(classifier(flattened), target)
        loss.backward()
        optimizer.step()

# Report predictions on the training examples (training accuracy only, not a
# held-out evaluation).
classifier.eval()
with torch.no_grad():
    for x, label in processed_data:
        output = classifier(x.view(1, -1, x.size(-1)))
        print(f"Predicted: {output.argmax(dim=1).item()}, Actual: {label}")


Predicted: 1, Actual: 1
Predicted: 1, Actual: 0
Predicted: 1, Actual: 1
Predicted: 1, Actual: 1
