In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class STRAttention(nn.Module):
    def __init__(self, vit_embedding_dim, reduction_ratio=4):
        super(STRAttention, self).__init__()
        reduced_channels = vit_embedding_dim // reduction_ratio
        
        # Spatial Attention (1D since ViT outputs a sequence)
        self.conv_s = nn.Conv1d(vit_embedding_dim, reduced_channels, kernel_size=3, padding=1)
        self.conv1x1_s = nn.Conv1d(reduced_channels, vit_embedding_dim, kernel_size=1)
        
        # Temporal Attention (3D conv over time dimension)
        self.conv_t = nn.Conv1d(vit_embedding_dim, reduced_channels, kernel_size=3, padding=1)
        self.conv1x1_t = nn.Conv1d(reduced_channels, vit_embedding_dim, kernel_size=1)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, vit_encodings):
        # Input shape: (Batch, Time, Channels) from ViT encoder
        print(vit_encodings.shape)
        b, t, c = vit_encodings.shape
        x = vit_encodings.permute(0, 2, 1)  # Change to (Batch, Channels, Time)
        
        # Spatial Attention
        attn_s = self.conv_s(x)
        attn_s = self.conv1x1_s(attn_s)
        attn_s = self.sigmoid(attn_s)
        
        # Temporal Attention
        attn_t = self.conv_t(x)
        attn_t = self.conv1x1_t(attn_t)
        attn_t = self.sigmoid(attn_t)
        
        # Apply attention weights
        x = x * attn_s * attn_t
        return x.permute(0, 2, 1)  # Back to (Batch, Time, Channels)

In [17]:
import os
import torch
import timm
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define image transformation (same as ViT expects)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Custom Dataset for loading images from a folder
class ImageDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.image_files = sorted([f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.png', '.tif'))])
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self.image_files[idx]  # Return filename as well for reference

# Load pre-trained ViT for feature extraction
vit_model = timm.create_model('vit_base_patch16_224', pretrained=True)
vit_model.head = torch.nn.Identity()  # Remove classification head
vit_model.eval()

# Function to process images through ViT
def extract_vit_embeddings(dataloader, device):
    embeddings = []
    filenames = []
    with torch.no_grad():
        for images, file_names in dataloader:
            images = images.to(device)
            features = vit_model(images)  # Extract ViT embeddings
            embeddings.append(features.cpu().numpy())
            filenames.extend(file_names)
    return np.array(embeddings), filenames

# Function to group frames into sequences (for temporal attention)
def create_sequences(embeddings, seq_len=5):
    sequences = []
    for i in range(len(embeddings) - seq_len + 1):
        sequences.append(embeddings[i:i + seq_len])
    return np.array(sequences)

# Full pipeline function
def process_images_with_attention(folder_path, seq_len=5, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else "cpu")

    # Load dataset and create DataLoader
    dataset = ImageDataset(folder_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
    print(len(dataset))

    # Extract ViT embeddings
    vit_embeddings, filenames = extract_vit_embeddings(dataloader, device)

    # Reshape embeddings for spatiotemporal attention (Batch, Time, Channels)
    sequences = create_sequences(vit_embeddings, seq_len)
    sequences_tensor = torch.tensor(sequences, dtype=torch.float32).to(device)

    # Initialize and apply STRAttention
    str_attention = STRAttention(vit_embedding_dim=768).to(device)
    with torch.no_grad():
        output_embeddings = str_attention(sequences_tensor)

    return output_embeddings.cpu().numpy(), filenames

In [18]:
folder_path = "/home/sagemaker-user/rc/data/test"
output_embeddings, filenames = process_images_with_attention(folder_path)

print("Processed embeddings shape:", output_embeddings.shape)

['001.tif', '002.tif', '003.tif', '004.tif', '005.tif', '006.tif', '007.tif', '008.tif', '009.tif', '010.tif']
10


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [5]:
from torchvision import transforms
from dataset import ImageDataset

transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor()
])

dataset = ImageDataset(
    root_dir = "data/UCSDped1/Train/",
    transform = transform
)

dataloader = DataLoader

In [7]:
dataset[1].shape

torch.Size([10, 3, 150, 150])