# Convert data to model input format

- images to temporal images
- text to encodings

In [None]:
import numpy as np 
from transformers import AutoTokenizer
from torch.utils.data import Dataset
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import torch

String conversion

In [None]:
# Load the pretrained tokenizer once; it is kept as a module-level name.
tokenizer = AutoTokenizer.from_pretrained("t5-small")  # or "gpt2", etc.

# Round-trip a few sample phrases: text -> token ids -> text.
text = [
    "rough surface",
    "sliding contact",
    "object tilting",
]
tokenized = tokenizer(
    text,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(tokenized)

input_ids = tokenized["input_ids"]
print(input_ids)

# Decode only the first phrase, leaving special tokens out of the text.
first_row = input_ids[0].squeeze()
decoded = tokenizer.decode(first_row, skip_special_tokens=True)
print(decoded)

{'input_ids': tensor([[ 8678,  1774,     1,  ...,     0,     0,     0],
        [13441,   574,     1,  ...,     0,     0,     0],
        [ 3735, 20382,    53,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}
tensor([[ 8678,  1774,     1,  ...,     0,     0,     0],
        [13441,   574,     1,  ...,     0,     0,     0],
        [ 3735, 20382,    53,  ...,     0,     0,     0]])
rough surface


In [13]:
def tokenize_captions(captions, max_len=30):
    """Tokenize captions into fixed-length PyTorch tensors.

    Relies on the module-level ``tokenizer``. Padding and truncation are both
    enabled with ``max_length=max_len``, so every caption is brought to the
    same length.

    Args:
        captions: Text (or a batch of texts) handed straight to the tokenizer.
        max_len: Sequence length used for both padding and truncation.

    Returns:
        The tokenizer's output, requested as PyTorch tensors
        (``return_tensors="pt"``).
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_len,
        "return_tensors": "pt",
    }
    return tokenizer(captions, **encode_options)

# Train model to predict text

In [15]:
class TactileDataset(Dataset):
    """Dataset yielding an image together with its tokenized caption.

    Args:
        images: Indexable collection of images; its length is the dataset size.
        captions: Indexable collection of captions, looked up with the same
            index as ``images``.
        tokenizer: Callable invoked as ``tokenizer(caption, max_length=...,
            padding="max_length", truncation=True, return_tensors="pt")`` and
            returning a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, images, captions, tokenizer, max_len=30):
        self.images, self.captions = images, captions
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        # The size follows the image collection only; captions are assumed to
        # be index-aligned with it.
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``input_ids`` and ``attention_mask``."""
        image = self.images[idx]
        encoded = self.tokenizer(
            self.captions[idx],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops the leading dimension when it has size 1
        # (presumably the batch axis added for a single caption).
        sample = {"image": image}
        for key in ("input_ids", "attention_mask"):
            sample[key] = encoded[key].squeeze(0)
        return sample


In [18]:
class TactileCaptioningModel(nn.Module):
    """Caption decoder conditioned on ResNet-18 image features.

    The image is encoded to a single feature vector, projected to the
    embedding size, and used as a one-token memory for a Transformer decoder
    that predicts the next caption token at every position.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output layer.
        embed_dim: Width of token embeddings and of the Transformer decoder.
        hidden_dim: Accepted for interface compatibility; currently unused.
        max_len: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, max_len=30):
        super().__init__()
        # Image encoder: ResNet18 without classifier.
        # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
        # `weights=`; kept as-is so older torchvision versions keep working.
        resnet = models.resnet18(pretrained=True)
        resnet.fc = nn.Identity()
        self.encoder = resnet

        # Project image features
        self.img_proj = nn.Linear(512, embed_dim)

        # Text decoder: Embedding + Transformer Decoder
        # NOTE(review): no positional encoding is added to the token
        # embeddings -- consider adding one.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=4)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, images, input_ids):
        """Return next-token logits for every position of ``input_ids``.

        Args:
            images: Batch of images fed to the ResNet encoder.
            input_ids: Token ids of shape (B, T) used as decoder input.

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        image_features = self.encoder(images)  # (B, 512)
        image_features = self.img_proj(image_features)  # (B, E)
        image_features = image_features.unsqueeze(1)  # (B, 1, E)

        tgt = self.embedding(input_ids)  # (B, T, E)
        tgt = tgt.transpose(0, 1)  # (T, B, E)
        memory = image_features.transpose(0, 1)  # (1, B, E)

        # Causal mask: True marks positions that must NOT be attended to.
        # Without it, position t can look at token t+1 -- the very token it is
        # trained to predict -- so the decoder learns to copy instead of
        # generate.
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )

        out = self.transformer_decoder(tgt, memory, tgt_mask=causal_mask)  # (T, B, E)
        out = out.transpose(0, 1)  # (B, T, E)

        return self.output_layer(out)

In [None]:
# Build the captioning model sized to the tokenizer's vocabulary and move it
# to the GPU when one is available.
model = TactileCaptioningModel(vocab_size=tokenizer.vocab_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Positions whose target is the padding id are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# NOTE(review): X (images) and y (captions) are not defined in this section;
# presumably they come from an earlier cell -- confirm.
dataset = TactileDataset(X, y, tokenizer)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in dataloader:
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)

        # The tokenizer output starts directly with the first content token
        # (no BOS). Feeding input_ids[:, :-1] and targeting input_ids[:, 1:]
        # would therefore never train the model to produce the first caption
        # token. Shift the decoder input right instead, using the pad id as
        # the start token, and predict the full sequence.
        start_tokens = torch.full_like(input_ids[:, :1], tokenizer.pad_token_id)
        decoder_input = torch.cat([start_tokens, input_ids[:, :-1]], dim=1)

        outputs = model(images, decoder_input)
        # Flatten (B, T, V) logits and (B, T) targets for cross-entropy; the
        # class count is read from the logits rather than assumed.
        loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), input_ids.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: Loss = {total_loss / len(dataloader):.4f}")

# Use the text with another LLM