In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class SimModel(nn.Module):
    def __init__(self, input_size=1, latent_size=1, is_inhibitor=False, signed_inhibitor=True):
        super(SimModel, self).__init__()
        self.input_size = input_size
        self.latent_size = latent_size
        self.is_inhibitor = is_inhibitor
        self.signed_inhibitor = signed_inhibitor
        
        self.q_lin = nn.Linear(input_size, latent_size)
        self.k_lin = nn.Linear(input_size, latent_size)
        self.v_lin = nn.Linear(input_size, latent_size)
        self.out_lin = nn.Linear(latent_size, latent_size)
        
    def compute_signed_inhibitor(self, scores: torch.Tensor, v: torch.Tensor, v_t: torch.Tensor) -> torch.Tensor:
        pos_v = F.relu(v_t)  # Positive part of V
        neg_v = -F.relu(-v_t)  # Negative part of V
        v_sum = torch.sum(v, dim=-2, keepdim=True)  # Sum over keys
        dist1 = torch.cdist(scores, pos_v, p=1)  # Distance to positive V
        dist2 = torch.cdist(scores, neg_v, p=1)  # Distance to negative V
        context = 0.5 * (v_sum + dist1 - dist2)
        return context

    def compute_unsigned_inhibitor(self, scores: torch.Tensor, v: torch.Tensor, v_t: torch.Tensor) -> torch.Tensor:
        v_sum = torch.sum(v, dim=-2, keepdim=True)  # Sum over keys
        z_sum = torch.sum(scores, dim=-1, keepdim=True)  # Sum of scores
        abs_diff = torch.cdist(scores, v_t, p=1)  # Absolute differences
        context = 0.5 * (v_sum - z_sum + abs_diff)
        return context

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        q = self.q_lin(x)  # (B, S, latent_size)
        k = self.k_lin(x)  # (B, S, latent_size)
        v = self.v_lin(x)  # (B, S, latent_size)
        
        if self.is_inhibitor:

            scores = torch.cdist(q, k, p=1) / self.latent_size  # (B, S, S)
             
            scores = F.relu(scores - scores.mean(dim=-1, keepdim=True))  # (B, S, S)           

            dropout = nn.Dropout(p=0.1)  
            mask_dropped = dropout(mask.to(scores.dtype))  # (B, S)
            
            mask_expanded = (mask_dropped == 0).unsqueeze(1)  # (B, 1, S)
            
            M = torch.max(torch.abs(v)) + 1
            scores = scores.masked_fill(mask_expanded, M)  # (B, S, S)
            
           
            v_t = v.transpose(-1, -2)  # (B, latent_size, S)
            
            if self.signed_inhibitor:
                context = self.compute_signed_inhibitor(scores, v, v_t)  # (B, S, latent_size)
            else:
                context = self.compute_unsigned_inhibitor(scores, v, v_t)  # (B, S, latent_size)
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.latent_size ** 0.5)  # (B, S, S)
            
            mask_expanded = (mask == 0).unsqueeze(1)  # (B, 1, S)
            
            scores = scores.masked_fill(mask_expanded, float('-inf'))  # (B, S, S)
            
            weights = F.softmax(scores, dim=-1)  # (B, S, S)
            

            weights = F.dropout(weights, p=0.1, training=self.training)  # (B, S, S)
            

            context = torch.matmul(weights, v)  
        
        h = self.out_lin(context)  
        
        return h, scores

In [None]:
class CosineEmbeddingLossWrapper(nn.Module):
    def __init__(self):
        super(CosineEmbeddingLossWrapper, self).__init__()
        self.cosine_loss = nn.CosineEmbeddingLoss()

    def forward(self, student_context, target_context):
        # Reshape to (batch_size * seq_length, latent_size)
        student = student_context.view(-1, student_context.size(-1))
        target = target_context.view(-1, target_context.size(-1))
        
        # Labels: 1 for similar, since we want to align them
        labels = torch.ones(student.size(0)).to(student_context.device)
        
        loss = self.cosine_loss(student, target, labels)
        return loss

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, num_samples=1000, seq_length=10, input_size=1, latent_size=1):
        super(ToyDataset, self).__init__()
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.input_size = input_size
        self.latent_size = latent_size
        
        # Random input data
        self.x = torch.randn(num_samples, seq_length, input_size)
        
        # Define target context as some function of x, for example, sum over the sequence
        # Here, we'll just use random target contexts for demonstration
        self.target_h = torch.randn(num_samples, seq_length, latent_size)
        
        # Attention mask: 1 for valid tokens, 0 for padding
        # For simplicity, assume all sequences are fully valid
        self.mask = torch.ones(num_samples, seq_length)
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx], self.target_h[idx]

In [None]:
import torch.optim as optim

# Hyperparameters
input_size = 1
latent_size = 1  # Increased for better representation
is_inhibitor = True
signed_inhibitor = True
num_epochs = 20
batch_size = 32
learning_rate = 1e-3

# Initialize the model
model = SimModel(
    input_size=input_size,
    latent_size=latent_size,
    is_inhibitor=is_inhibitor,
    signed_inhibitor=signed_inhibitor
)

# Move model to device
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#model.to(device)

# Initialize the loss function
loss_fn = CosineEmbeddingLossWrapper()

# Initialize the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Initialize the dataset and dataloader
dataset = ToyDataset(num_samples=1000, seq_length=10, input_size=input_size, latent_size=latent_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training Loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, (x, mask, target_h) in enumerate(dataloader):
        x = x.to(device)  # (B, S, input_size)
        mask = mask.to(device)  # (B, S)
        target_h = target_h.to(device)  # (B, S, latent_size)
        
        # Forward pass
        h, scores = model(x, mask)  # h: (B, S, latent_size)
        
        # Compute loss
        loss = loss_fn(h, target_h)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

In [3]:
import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

class SimpleTransformer:
    def __init__(self, input_dim, hidden_dim, output_dim, attention_type='dot_product'):
        self.Wq = np.random.randn(input_dim, hidden_dim)
        self.Wk = np.random.randn(input_dim, hidden_dim)
        self.Wv = np.random.randn(input_dim, hidden_dim)
        self.Wo = np.random.randn(hidden_dim, output_dim)
        self.attention_type = attention_type

    def dot_product_attention(self, Q, K, V):
        attention_scores = np.dot(Q, K.T) / np.sqrt(K.shape[1])
        attention_probs = softmax(attention_scores)
        return np.dot(attention_probs, V)

    def inhibitor_attention(self, Q, K, V):
        Z = np.abs(Q[:, np.newaxis] - K)
        return np.maximum(V[:, np.newaxis] - Z, 0).sum(axis=1)

    def forward(self, X):
        Q = np.dot(X, self.Wq)
        K = np.dot(X, self.Wk)
        V = np.dot(X, self.Wv)

        if self.attention_type == 'dot_product':
            attention_output = self.dot_product_attention(Q, K, V)
        else:
            attention_output = self.inhibitor_attention(Q, K, V)

        return np.dot(attention_output, self.Wo)

# Set random seed for reproducibility
np.random.seed(42)

# Model parameters
input_dim = 4
hidden_dim = 2
output_dim = 3
seq_length = 5

# Create input data
X = np.random.randn(seq_length, input_dim)

# Initialize models
dot_product_model = SimpleTransformer(input_dim, hidden_dim, output_dim, 'dot_product')
inhibitor_model = SimpleTransformer(input_dim, hidden_dim, output_dim, 'inhibitor')

# Use the same weights for both models to ensure fair comparison
inhibitor_model.Wq = dot_product_model.Wq
inhibitor_model.Wk = dot_product_model.Wk
inhibitor_model.Wv = dot_product_model.Wv
inhibitor_model.Wo = dot_product_model.Wo

# Forward pass
dot_product_output = dot_product_model.forward(X)
inhibitor_output = inhibitor_model.forward(X)

print("Dot-product attention output:")
print(dot_product_output)
print("\nInhibitor attention output:")
print(inhibitor_output)

# Calculate cosine similarity
dot_product = np.sum(dot_product_output * inhibitor_output, axis=1)
norm_dot = np.linalg.norm(dot_product_output, axis=1)
norm_inhibitor = np.linalg.norm(inhibitor_output, axis=1)
cosine_similarity = dot_product / (norm_dot * norm_inhibitor)

print("\nCosine similarity between outputs:")
print(cosine_similarity)

# Calculate Euclidean distance
euclidean_distance = np.linalg.norm(dot_product_output - inhibitor_output, axis=1)
print("\nEuclidean distance between outputs:")
print(euclidean_distance)

# Calculate Mean Squared Error
mse = np.mean((dot_product_output - inhibitor_output)**2, axis=1)
print("\nMean Squared Error between outputs:")
print(mse)

Dot-product attention output:
[[ 3.59216446  1.42199006 -2.87967532]
 [ 3.77442754  1.49180249 -3.0543858 ]
 [-1.94609717 -0.86616058  0.38851636]
 [ 3.91596314  1.5419788  -3.23942727]
 [-1.97551454 -0.87585477  0.4359626 ]]

Inhibitor attention output:
[[ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.92758011  0.30151054 -1.54699329]
 [-1.77567688 -0.86451925 -0.55321843]
 [ 2.61997468  0.85162453 -4.36952366]]

Cosine similarity between outputs:
[        nan         nan -0.67359451 -0.59622408 -0.6873109 ]

Euclidean distance between outputs:
[4.81852998 5.07947345 3.65618306 6.73807961 6.86989113]

Mean Squared Error between outputs:
[ 7.73941039  8.60035019  4.45589152 15.13390561 15.73180138]


  cosine_similarity = dot_product / (norm_dot * norm_inhibitor)
