In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(context_dim, dim, bias=False)
        self.to_v = nn.Linear(context_dim, dim, bias=False)
        self.scale = dim ** -0.5  

    def forward(self, x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        return out + x  # Residual Connection

def test_cross_attention():
    dim = 64  # Feature dimension
    context_dim = 64  # Context dimension
    batch_size = 2  # Batch size
    seq_len_x = 10  # Sequence length of input
    seq_len_ctx = 15  # Sequence length of context
    
    # Create dummy input tensors
    x = torch.randn(batch_size, seq_len_x, dim)
    context = torch.randn(batch_size, seq_len_ctx, context_dim)
    
    # Initialize CrossAttention module
    cross_attn = CrossAttention(dim, context_dim)
    
    # Forward pass
    output = cross_attn(x, context)
    
    # Print input and output
    print("Input (x):", x.shape)
    print("Context:", context.shape)
    print("Output:", output.shape)
    
    # Check output shape
    assert output.shape == x.shape, f"Expected shape {x.shape}, but got {output.shape}"
    
    # Check if attention mechanism preserves input size
    attn_scores = (cross_attn.to_q(x) @ cross_attn.to_k(context).transpose(-2, -1)) * cross_attn.scale
    attn_probs = attn_scores.softmax(dim=-1)
    assert torch.allclose(attn_probs.sum(dim=-1), torch.ones_like(attn_probs.sum(dim=-1)), atol=1e-5), "Attention probabilities should sum to 1"
    
    print("All tests passed!")

test_cross_attention()


Input (x): torch.Size([2, 10, 64])
Context: torch.Size([2, 15, 64])
Output: torch.Size([2, 10, 64])
All tests passed!


In [8]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import io

# Đọc file Parquet
df = pd.read_parquet("0000.parquet")

# Giả sử cột chứa ảnh là 'image_bytes'
img_data = df.iloc[0]["image"]  # Lấy ảnh từ dòng đầu tiên

# Chuyển bytes thành ảnh và hiển thị
image = Image.open(io.BytesIO(img_data))
plt.imshow(image)
plt.axis("off")  # Ẩn trục tọa độ
plt.show()


TypeError: a bytes-like object is required, not 'dict'