In [1]:
import torch

from transformers import GPT2Model, GPT2Config, GPT2Tokenizer

class GPT2WithIntermediateOutputs(GPT2Model):

    def __init__(self, config):

        super().__init__(config)

    def forward(self, input_ids, attention_mask=None):

        # Get embeddings from input

        input_shape = input_ids.size()

        input_ids = input_ids.view(-1, input_shape[-1])

        device = input_ids.device

        if attention_mask is None:

            attention_mask = torch.ones(input_shape, device=device)

        # Prepare attention mask

        if attention_mask.dim() == 3:

            extended_attention_mask = attention_mask[:, None, :, :]

        elif attention_mask.dim() == 2:

            extended_attention_mask = attention_mask[:, None, None, :]

        # Prepare head mask if needed

        head_mask = self.get_head_mask(None, self.config.n_layer)

        # Transformer layers

        hidden_states = self.wte(input_ids) + self.wpe(torch.arange(0, input_shape[-1], device=device))

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_hidden_states = torch.empty((*output_shape, self.config.n_layer + 1), device=device)

        all_hidden_states[..., 0] = hidden_states

        for i, (block, layer_past) in enumerate(zip(self.h, [None]*len(self.h))):

            outputs = block(hidden_states, layer_past=layer_past, attention_mask=extended_attention_mask, head_mask=head_mask[i])

            hidden_states = outputs[0]

            all_hidden_states[..., i+1] = hidden_states

        # Concatenate all hidden states

        concatenated_outputs = all_hidden_states.permute(2, 0, 1, 3).reshape(self.config.n_layer + 1, -1, hidden_states.size(-1))

        return concatenated_outputs
 

# Example usage:

config = GPT2Config.from_pretrained("gpt2")

model = GPT2WithIntermediateOutputs(config)

print(model)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = torch.tensor([tokenizer.encode("if the vocabulary is known, then the sequence length is correct")])

# Output will have dimensions: [num_layers + 1, batch_size, seq_length, features] (discard element 0 , that's the embedding initial layer)

outputs = model(input_ids)


print(outputs.shape)  # Output dimensions

  from .autonotebook import tqdm as notebook_tqdm


GPT2WithIntermediateOutputs(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
torch.Size([13, 12, 768])


In [2]:
import pickle

config = GPT2Config.from_pretrained('gpt2', output_hidden_states=True)
model = GPT2WithIntermediateOutputs(config)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Example data
data = [
    ("Hello, world!", 1),
    ("This is a test.", 0),
    ("This isn't a test.", 1),
    ("Another much longer test", 1)
]

# Process each sentence
processed_data = []
for text, label in data:
    input_ids = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)
        print(outputs.shape)
       

    processed_data.append((outputs, label))

# Save the processed data
with open('processed_data.pkl', 'wb') as f:
    pickle.dump(processed_data, f)

torch.Size([13, 4, 768])
torch.Size([13, 5, 768])
torch.Size([13, 6, 768])
torch.Size([13, 4, 768])


In [3]:
import torch.nn as nn
from gaussian_adaptive_attention import MultiHeadGaussianAdaptiveAttention

class ClassifierWithGAAM(nn.Module):
    def __init__(self, num_classes, num_gaussians, norm_axis):
        super().__init__()
        self.gaam = MultiHeadGaussianAdaptiveAttention(
            norm_axis=norm_axis,  # Choose either -2 (sequence length) or 0 (layer number)
            num_heads=5,
            num_gaussians=num_gaussians,
            padding_value=0
        )
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.fc = nn.Linear(64, num_classes)  # This needs to be adjusted based on output size

    def forward(self, x):
        x = self.gaam(x.unsqueeze(0))  # Add channel dimension
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = nn.AdaptiveAvgPool2d((1, 1))(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Initialize classifier
classifier = ClassifierWithGAAM(num_classes=4, num_gaussians=3, norm_axis=-2)


In [4]:
import torch
import torch.nn as nn
from gaussian_adaptive_attention import MultiHeadGaussianAdaptiveAttention

class GAAAttentionClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, num_heads, num_gaussians):
        super(GAAAttentionClassifier, self).__init__()
        # Gaussian Adaptive Attention layer
        self.gaam = MultiHeadGaussianAdaptiveAttention(
            num_heads=num_heads,
            num_gaussians=num_gaussians,
            norm_axis=-2,  # Typically, we normalize over the sequence length dimension
            padding_value=0
        )
        # Fully connected layers
        self.fc1 = nn.Linear(input_dim, 512)  # Adjust input_dim to match the output of gaam if necessary
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # x shape might be [batch_size, seq_length, feature_dim]
        # Apply Gaussian Adaptive Attention
        x = self.gaam(x)
        # Assuming we pool over the sequence length dimension
        x = torch.mean(x, dim=1)  # Reduce over the sequence length
        # Pass through fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Parameters
input_dim = 768  # Assumed dimension of each vector in the sequence
num_classes = 4  # Number of target classes
num_heads = 12    # Number of attention heads
num_gaussians = 3  # Number of Gaussian distributions in the attention mechanism

# Create the classifier
gaam_classifier = GAAAttentionClassifier(input_dim, num_classes, num_heads, num_gaussians)


In [5]:
import torch
import torch.nn as nn

class GPT2Classifier(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super(GPT2Classifier, self).__init__()
        self.fc1 = nn.Linear(hidden_dim, 512)  
        self.relu = nn.ReLU()  
        self.fc2 = nn.Linear(512, 128)  
        self.dropout = nn.Dropout(0.1) 
        self.fc3 = nn.Linear(128, num_classes)  
    def forward(self, x):
        x = torch.mean(x, dim=1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

hidden_dim = 768 
num_classes = 4  

basic_classifier = GPT2Classifier(hidden_dim, num_classes)


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttentionClassifier(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super(DotProductAttentionClassifier, self).__init__()
        # Transformation layers for queries, keys, and values
        self.query_transform = nn.Linear(hidden_dim, hidden_dim)
        self.key_transform = nn.Linear(hidden_dim, hidden_dim)
        self.value_transform = nn.Linear(hidden_dim, hidden_dim)
        
        self.scale = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))  # Scaling factor for dot product
        
        # Fully connected layers
        self.fc1 = nn.Linear(hidden_dim, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 128)
        self.dropout = nn.Dropout(0.1)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_dim]
        
        # Transform input for queries, keys, values
        queries = self.query_transform(x)
        keys = self.key_transform(x)
        values = self.value_transform(x)
        
        # Apply dot product attention
        scores = torch.bmm(queries, keys.transpose(1, 2)) / self.scale
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights, values)  # Context is the weighted sum of value vectors
        
        # Reduce the sequence dimension by averaging
        x = torch.mean(context, dim=1)
        
        # Pass through fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

hidden_dim = 768
num_classes = 4

# Create the classifier with dot product attention
attention_classifier = DotProductAttentionClassifier(hidden_dim, num_classes)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from gaussian_adaptive_attention import MultiHeadGaussianAdaptiveAttention

class ClassifierWithGAAM(nn.Module):
    def __init__(self, hidden_dim, num_classes, num_gaussians, norm_axis):
        super(ClassifierWithGAAM, self).__init__()
        self.gaam = MultiHeadGaussianAdaptiveAttention(
            norm_axis=norm_axis,  # Choose either -2 (sequence length) or 0 (layer number)
            num_heads=5,
            num_gaussians=num_gaussians,
            padding_value=0
        )
        
        # Fully connected layers
        self.fc1 = nn.Linear(hidden_dim, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 128)
        self.dropout = nn.Dropout(0.1)
        self.fc3 = nn.Linear(128, num_classes)

               # Initialize GAAM weights
        self._init_gaam_weights()

    def _init_gaam_weights(self):
        for name, param in self.gaam.named_parameters():
            if 'weight' in name:
                init.xavier_uniform_(param)
            elif 'bias' in name:
                init.constant_(param, 0)

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_dim]
        
        # Apply Gaussian Adaptive Attention
        context = self.gaam(x)
        
        # Reduce the sequence dimension by averaging
        x = torch.mean(context, dim=1)
        
        # Pass through fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Define the parameters for your GAAM classifier
hidden_dim = 768
num_classes = 4
num_gaussians = 3
norm_axis = -2

# Create the classifier with GAAM
gaam_classifier3 = ClassifierWithGAAM(hidden_dim, num_classes, num_gaussians, norm_axis)


In [8]:
from gaussian_adaptive_attention import GaussianAdaptiveAttention

class ClassifierWithGaussianAttention(nn.Module):
    def __init__(self, hidden_dim, num_classes, num_gaussians, norm_axis):
        super(ClassifierWithGaussianAttention, self).__init__()
        self.attention = GaussianAdaptiveAttention(
            norm_axis=norm_axis,
            num_heads=1,  # Since GaussianAttention is not multi-headed
            num_gaussians=num_gaussians,
            padding_value=0  # You can set this according to your data
        )
        
        # Fully connected layers
        self.fc1 = nn.Linear(hidden_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)
        
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_dim]
        
        # Apply Gaussian Adaptive Attention
        context = self.attention(x)
        
        # Reduce the sequence dimension by averaging
        x = torch.mean(context, dim=1)
        
        # Pass through fully connected layers
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.1)
        x = self.fc3(x)
        return x

# Define the parameters for your Gaussian Attention classifier
hidden_dim = 768
num_classes = 4
num_gaussians = 128
norm_axis = -2  # Assuming you want to normalize along the sequence length axis

# Create the classifier with Gaussian Attention
gaussian_classifier = ClassifierWithGaussianAttention(hidden_dim, num_classes, num_gaussians, norm_axis)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianAttentionClassifier(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super(GaussianAttentionClassifier, self).__init__()
        self.hidden_dim = hidden_dim

        # Transformation layer for values
        self.value_transform = nn.Linear(hidden_dim, hidden_dim)
        
        # Parameters for Gaussian attention
        self.mean = nn.Parameter(torch.rand(1))  # Learnable mean
        self.std_dev = nn.Parameter(torch.rand(1))  # Learnable standard deviation
        
        # Fully connected layers
        self.fc1 = nn.Linear(hidden_dim, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 128)
        self.dropout = nn.Dropout(0.1)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x, mask=None):
        batch_size, seq_length, _ = x.shape
        
        # Dynamically create position indices based on sequence length
        position_indices = torch.arange(seq_length, device=x.device).float().unsqueeze(0)  # [1, seq_length]

        # Transform input for values
        values = self.value_transform(x)  # [batch_size, seq_length, hidden_dim]

        # Calculate Gaussian attention weights
        position_diff = position_indices - self.mean  # [1, seq_length]
        exponent = -0.5 * ((position_diff / self.std_dev) ** 2)
        gaussian_weights = torch.exp(exponent)  # [1, seq_length]
        gaussian_weights = gaussian_weights / gaussian_weights.sum()  # Normalize weights
        gaussian_weights = gaussian_weights.repeat(batch_size, 1).unsqueeze(1)  # [batch_size, 1, seq_length]

        # Apply masking if provided
        if mask is not None:
            mask = mask.float().unsqueeze(1)  # [batch_size, 1, seq_length]
            gaussian_weights *= mask  # Apply mask to Gaussian weights

        # Apply Gaussian attention
        context = torch.bmm(gaussian_weights, values)  # [batch_size, 1, hidden_dim]
        x = context.squeeze(1)  # [batch_size, hidden_dim]

        # Pass through fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Parameters
hidden_dim = 768
num_classes = 4

# Create the classifier with Gaussian attention
gaussian_attention_classifier = GaussianAttentionClassifier(hidden_dim, num_classes)

# Example usage
# Assuming 'input_tensor' is your input batch of shape [batch_size, seq_length, hidden_dim]
# 'mask' is a Boolean tensor of shape [batch_size, seq_length] indicating valid positions (True for valid, False for padded)
# output = gaussian_attention_classifier(input_tensor, mask)


In [10]:
import torch.optim as optim

# Load data
with open('processed_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

# Setup optimizer and loss function
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training classifier
for x, label in processed_data:
    flattened = x.view(1,-1,768)
    output = gaam_classifier(flattened)
    print(f"Predicted: {torch.argmax(output)}, Actual: {label}")


Predicted: 0, Actual: 1
Predicted: 0, Actual: 0
Predicted: 0, Actual: 1
Predicted: 0, Actual: 1


In [11]:
from datasets import Dataset, DatasetDict, load_dataset
import pandas as pd

dataset = load_dataset('ag_news')


def take_a_percentage_of_data(dataset, percentage=0.1, shuffle=True, random_state=None):
    df = pd.DataFrame(dataset)
    df_sorted = df.sort_values(by='label')
    grouped = df_sorted.groupby('label')

    # ensure that proportions of the groups remains the same
    filtered_dfs = []
    for label, group in grouped:
        num_samples_to_keep = int(len(group) * percentage)
        filtered_group = group.head(num_samples_to_keep)
        filtered_dfs.append(filtered_group)

    filtered_df = pd.concat(filtered_dfs)
    if shuffle:
        filtered_df = filtered_df.sample(frac=1, random_state=random_state)

    # filtered_df = pd.concat(filtered_dfs)
    filtered_df.reset_index(drop=True, inplace=True)
    filtered_dict = filtered_df.to_dict(orient='list')
    filtered_dataset = Dataset.from_dict(filtered_dict)
    return filtered_dataset

dataset_train_1percent = take_a_percentage_of_data(dataset['train'], percentage=0.001)
dataset_test_1percent = take_a_percentage_of_data(dataset['test'], percentage=0.01)

combined_dataset_1percent = DatasetDict({
    'train': dataset_train_1percent,
    'test': dataset_test_1percent
})

In [12]:
print(len(combined_dataset_1percent['train']))
print(len(combined_dataset_1percent['test']))

120
76


In [13]:
# Process each sentence
processed_data = []
for obj in combined_dataset_1percent['test']:
    label = obj['label']
    text = obj['text']
    input_ids = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)
        print(outputs.shape)
       

    processed_data.append((outputs, label))

# Save the processed data
with open('processed_test_data.pkl', 'wb') as f:
    pickle.dump(processed_data, f)

torch.Size([13, 76, 768])
torch.Size([13, 56, 768])
torch.Size([13, 65, 768])
torch.Size([13, 62, 768])
torch.Size([13, 40, 768])
torch.Size([13, 37, 768])
torch.Size([13, 79, 768])
torch.Size([13, 32, 768])
torch.Size([13, 46, 768])
torch.Size([13, 42, 768])
torch.Size([13, 41, 768])
torch.Size([13, 60, 768])
torch.Size([13, 73, 768])
torch.Size([13, 54, 768])
torch.Size([13, 55, 768])
torch.Size([13, 64, 768])
torch.Size([13, 114, 768])
torch.Size([13, 48, 768])
torch.Size([13, 40, 768])
torch.Size([13, 53, 768])
torch.Size([13, 58, 768])
torch.Size([13, 66, 768])
torch.Size([13, 53, 768])
torch.Size([13, 34, 768])
torch.Size([13, 38, 768])
torch.Size([13, 48, 768])
torch.Size([13, 52, 768])
torch.Size([13, 50, 768])
torch.Size([13, 48, 768])
torch.Size([13, 52, 768])
torch.Size([13, 51, 768])
torch.Size([13, 35, 768])
torch.Size([13, 66, 768])
torch.Size([13, 68, 768])
torch.Size([13, 55, 768])
torch.Size([13, 54, 768])
torch.Size([13, 69, 768])
torch.Size([13, 37, 768])
torch.Size(

In [14]:
# Process each sentence
processed_data = []
for obj in combined_dataset_1percent['train']:
    label = obj['label']
    text = obj['text']
    input_ids = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)
        print(outputs.shape)
       

    processed_data.append((outputs, label))

# Save the processed data
with open('processed_data.pkl', 'wb') as f:
    pickle.dump(processed_data, f)

torch.Size([13, 36, 768])
torch.Size([13, 79, 768])
torch.Size([13, 42, 768])
torch.Size([13, 54, 768])
torch.Size([13, 49, 768])
torch.Size([13, 26, 768])
torch.Size([13, 59, 768])
torch.Size([13, 59, 768])
torch.Size([13, 69, 768])
torch.Size([13, 27, 768])
torch.Size([13, 64, 768])
torch.Size([13, 43, 768])
torch.Size([13, 64, 768])
torch.Size([13, 61, 768])
torch.Size([13, 55, 768])
torch.Size([13, 72, 768])
torch.Size([13, 26, 768])
torch.Size([13, 46, 768])
torch.Size([13, 45, 768])
torch.Size([13, 50, 768])
torch.Size([13, 65, 768])
torch.Size([13, 44, 768])
torch.Size([13, 43, 768])
torch.Size([13, 74, 768])
torch.Size([13, 52, 768])
torch.Size([13, 78, 768])
torch.Size([13, 57, 768])
torch.Size([13, 44, 768])
torch.Size([13, 50, 768])
torch.Size([13, 38, 768])
torch.Size([13, 60, 768])
torch.Size([13, 47, 768])
torch.Size([13, 41, 768])
torch.Size([13, 40, 768])
torch.Size([13, 51, 768])
torch.Size([13, 81, 768])
torch.Size([13, 50, 768])
torch.Size([13, 40, 768])
torch.Size([

In [15]:
print('Classifier with GAAM 1:')
total_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")

print('Classifier with GAAM 2:')
total_params = sum(p.numel() for p in gaam_classifier.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")

print('Classifier with GAAM 3:')
total_params = sum(p.numel() for p in gaam_classifier3.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")

print('Classifier with GAAM 4:')
total_params = sum(p.numel() for p in gaussian_classifier.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")

print('GPT2 Classifier:')
total_params = sum(p.numel() for p in basic_classifier.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")

print('Attention Classifier:')
total_params = sum(p.numel() for p in attention_classifier.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params}")


Classifier with GAAM 1:
Total learnable parameters: 19106
Classifier with GAAM 2:
Total learnable parameters: 395852
Classifier with GAAM 3:
Total learnable parameters: 459938
Classifier with GAAM 4:
Total learnable parameters: 460164
GPT2 Classifier:
Total learnable parameters: 459908
Attention Classifier:
Total learnable parameters: 2231684


In [121]:
print(attention_classifier)
print(basic_classifier)
print(gaam_classifier)
print(classifier)
print(gaam_classifier3)

DotProductAttentionClassifier(
  (query_transform): Linear(in_features=768, out_features=768, bias=True)
  (key_transform): Linear(in_features=768, out_features=768, bias=True)
  (value_transform): Linear(in_features=768, out_features=768, bias=True)
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc3): Linear(in_features=128, out_features=4, bias=True)
)
GPT2Classifier(
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc3): Linear(in_features=128, out_features=4, bias=True)
)
GAAAttentionClassifier(
  (gaam): MultiHeadGaussianAdaptiveAttention(
    (attention_heads): ModuleList(
      (0-11): 12 x GaussianAdaptiveAttention()
    )
  )
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (relu): ReLU()


In [16]:
import torch.optim as optim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

use_classifier = gaussian_attention_classifier

with open('processed_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

optimizer = optim.Adam(use_classifier.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Assuming the classifier is ready and 'device' is defined
use_classifier.to(device)
use_classifier.train()

# Prepare data
features, labels = zip(*processed_data)
batch_size = 32
num_batches = len(features) // batch_size

num_epochs = 10 

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    for i in range(0, len(features), batch_size):
        optimizer.zero_grad()  
        total_loss = 0

        for f, label in zip(features[i:i+batch_size], labels[i:i+batch_size]):
            x = torch.tensor(f).view(1, -1, 768).to(device) 
            y = torch.tensor([label], dtype=torch.long).to(device)

            output = use_classifier(x)
            loss = criterion(output, y)
            loss.backward()  
            total_loss += loss.item()

        optimizer.step()

        print(f"Batch {i // batch_size + 1}/{num_batches + 1}: Average Loss: {total_loss / batch_size:.2f}")


Epoch 1/10


  x = torch.tensor(f).view(1, -1, 768).to(device)


Batch 1/3: Average Loss: 1.40
Batch 2/3: Average Loss: 2.02
Batch 3/3: Average Loss: 5.46
Batch 4/3: Average Loss: 5.14
Epoch 2/10
Batch 1/3: Average Loss: 9.54
Batch 2/3: Average Loss: 3.46
Batch 3/3: Average Loss: 4.70
Batch 4/3: Average Loss: 1.81
Epoch 3/10
Batch 1/3: Average Loss: 2.49
Batch 2/3: Average Loss: 2.23
Batch 3/3: Average Loss: 1.39
Batch 4/3: Average Loss: 1.20
Epoch 4/10
Batch 1/3: Average Loss: 1.34
Batch 2/3: Average Loss: 1.37
Batch 3/3: Average Loss: 1.43
Batch 4/3: Average Loss: 1.22
Epoch 5/10
Batch 1/3: Average Loss: 1.32
Batch 2/3: Average Loss: 1.30
Batch 3/3: Average Loss: 1.11
Batch 4/3: Average Loss: 1.10
Epoch 6/10
Batch 1/3: Average Loss: 1.30
Batch 2/3: Average Loss: 1.28
Batch 3/3: Average Loss: 1.05
Batch 4/3: Average Loss: 1.11
Epoch 7/10
Batch 1/3: Average Loss: 1.31
Batch 2/3: Average Loss: 1.10
Batch 3/3: Average Loss: 0.90
Batch 4/3: Average Loss: 1.05
Epoch 8/10
Batch 1/3: Average Loss: 1.29
Batch 2/3: Average Loss: 1.05
Batch 3/3: Average Loss

In [17]:
import torch.optim as optim
import torch
import pickle

with open('processed_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

use_classifier.to(device)
use_classifier.eval()

correct_predictions = 0
total_predictions = 0

for x, label in processed_data:
    x = x.to(device)
    label = torch.tensor([label]).to(device)
    
    with torch.no_grad():
        flattened = x.view(1, -1, 768)
        output = use_classifier(flattened)
        predicted_label = torch.argmax(output, dim=1)
        print(f"Predicted: {predicted_label.item()}, Actual: {label.item()}")

    correct_predictions += (predicted_label == label).sum().item()
    total_predictions += label.size(0)

accuracy = correct_predictions / total_predictions
print(f"Accuracy: {accuracy:.4f}")


Predicted: 0, Actual: 1
Predicted: 1, Actual: 1
Predicted: 0, Actual: 1
Predicted: 2, Actual: 3
Predicted: 1, Actual: 1
Predicted: 3, Actual: 0
Predicted: 1, Actual: 1
Predicted: 1, Actual: 1
Predicted: 2, Actual: 0
Predicted: 3, Actual: 3
Predicted: 2, Actual: 2
Predicted: 2, Actual: 2
Predicted: 2, Actual: 2
Predicted: 2, Actual: 1
Predicted: 2, Actual: 0
Predicted: 2, Actual: 1
Predicted: 3, Actual: 3
Predicted: 3, Actual: 1
Predicted: 3, Actual: 0
Predicted: 0, Actual: 3
Predicted: 2, Actual: 0
Predicted: 2, Actual: 2
Predicted: 2, Actual: 2
Predicted: 2, Actual: 3
Predicted: 2, Actual: 1
Predicted: 1, Actual: 1
Predicted: 2, Actual: 0
Predicted: 2, Actual: 3
Predicted: 2, Actual: 2
Predicted: 0, Actual: 1
Predicted: 2, Actual: 2
Predicted: 0, Actual: 2
Predicted: 3, Actual: 3
Predicted: 3, Actual: 3
Predicted: 0, Actual: 0
Predicted: 2, Actual: 3
Predicted: 2, Actual: 2
Predicted: 0, Actual: 0
Predicted: 1, Actual: 1
Predicted: 2, Actual: 3
Predicted: 0, Actual: 3
Predicted: 2, Ac

In [18]:
import torch.optim as optim
import torch
import pickle

with open('processed_test_data.pkl', 'rb') as f:
    processed_data = pickle.load(f)

use_classifier.to(device)
use_classifier.eval()

correct_predictions = 0
total_predictions = 0

for x, label in processed_data:
    x = x.to(device)
    label = torch.tensor([label]).to(device)
    
    with torch.no_grad():
        flattened = x.view(1, -1, 768)
        output = use_classifier(flattened)
        predicted_label = torch.argmax(output, dim=1)
        print(f"Predicted: {predicted_label.item()}, Actual: {label.item()}")

    correct_predictions += (predicted_label == label).sum().item()
    total_predictions += label.size(0)

accuracy = correct_predictions / total_predictions
print(f"Accuracy: {accuracy:.4f}")


Predicted: 2, Actual: 1
Predicted: 1, Actual: 3
Predicted: 2, Actual: 0
Predicted: 2, Actual: 3
Predicted: 0, Actual: 1
Predicted: 0, Actual: 1
Predicted: 2, Actual: 2
Predicted: 1, Actual: 0
Predicted: 2, Actual: 2
Predicted: 0, Actual: 3
Predicted: 3, Actual: 2
Predicted: 2, Actual: 0
Predicted: 1, Actual: 2
Predicted: 2, Actual: 2
Predicted: 2, Actual: 2
Predicted: 2, Actual: 3
Predicted: 1, Actual: 3
Predicted: 2, Actual: 1
Predicted: 1, Actual: 3
Predicted: 2, Actual: 3
Predicted: 1, Actual: 0
Predicted: 2, Actual: 2
Predicted: 0, Actual: 1
Predicted: 0, Actual: 3
Predicted: 3, Actual: 1
Predicted: 2, Actual: 0
Predicted: 2, Actual: 0
Predicted: 2, Actual: 0
Predicted: 1, Actual: 2
Predicted: 2, Actual: 3
Predicted: 0, Actual: 1
Predicted: 2, Actual: 1
Predicted: 2, Actual: 1
Predicted: 2, Actual: 1
Predicted: 1, Actual: 0
Predicted: 0, Actual: 2
Predicted: 2, Actual: 0
Predicted: 0, Actual: 0
Predicted: 3, Actual: 2
Predicted: 2, Actual: 2
Predicted: 3, Actual: 0
Predicted: 3, Ac

In [19]:
import torch

# Training code goes here...

# After training, save the model
torch.save(gaussian_classifier.state_dict(), 'gaussian_classifier.pth')
