In [1]:
!pip install -q datasets transformers peft accelerate bitsandbytes torchmetrics

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from datasets import load_dataset
from tqdm import tqdm
from transformers import T5TokenizerFast
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from torchmetrics.multimodal import CLIPScore as TorchCLIPScore

In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

In [25]:
# Load models
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
"""
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
latex_tokens = [
    "\\documentclass", "\\usepackage", "\\begin", "\\end",
    "\\tikzstyle", "\\node", "\\draw", "\\fill", "\\path",
    "\\pgfplotsset", "\\setlength", "\\foreach", "\\addplot",
    "tikztonodes", "tikztarget", "\\fill", "\\node",
    "\\begin{tikzpicture}", "\\end{tikzpicture}",
    "{", "}", "[", "]", "(", ")", "%", ";", "\\n"
]
tokenizer.add_tokens(latex_tokens)
"""
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base", trust_remote_code=True)
t5_model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-base")
t5_model.config.decoder_start_token_id = tokenizer.pad_token_id


#t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
#t5_model.resize_token_embeddings(len(tokenizer))

# Apply LoRA to T5 decoder
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
t5_model = get_peft_model(t5_model, lora_config)

In [26]:
class TikZGenModel(nn.Module):
    def __init__(self, clip_model, t5_model):
        super().__init__()
        self.clip_model = clip_model
        self.t5_model = t5_model
        self.fusion = nn.Linear(clip_model.config.projection_dim + t5_model.config.d_model, t5_model.config.d_model)

    def forward(self, image, input_ids, attention_mask, labels=None):
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(pixel_values=image)

        text_features = self.t5_model.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        combined_features = torch.cat([text_features, image_features], dim=1)
        fused = self.fusion(combined_features).unsqueeze(1)

        return self.t5_model(inputs_embeds=fused, labels=labels)

In [6]:
# Load full dataset (or first 15K examples)
full_data = load_dataset("nllg/datikz-v3", split="train[:2000]").shuffle(seed=42)

# Split 90% train / 10% test
split = full_data.train_test_split(test_size=0.1, seed=42)
train_raw = split["train"]
test_raw = split["test"]

README.md:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

train-00000-of-00014.parquet:   0%|          | 0.00/412M [00:00<?, ?B/s]

train-00001-of-00014.parquet:   0%|          | 0.00/442M [00:00<?, ?B/s]

train-00002-of-00014.parquet:   0%|          | 0.00/467M [00:00<?, ?B/s]

train-00003-of-00014.parquet:   0%|          | 0.00/461M [00:00<?, ?B/s]

train-00004-of-00014.parquet:   0%|          | 0.00/381M [00:00<?, ?B/s]

train-00005-of-00014.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

train-00006-of-00014.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

train-00007-of-00014.parquet:   0%|          | 0.00/462M [00:00<?, ?B/s]

train-00008-of-00014.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

train-00009-of-00014.parquet:   0%|          | 0.00/447M [00:00<?, ?B/s]

train-00010-of-00014.parquet:   0%|          | 0.00/460M [00:00<?, ?B/s]

train-00011-of-00014.parquet:   0%|          | 0.00/458M [00:00<?, ?B/s]

train-00012-of-00014.parquet:   0%|          | 0.00/451M [00:00<?, ?B/s]

train-00013-of-00014.parquet:   0%|          | 0.00/441M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/33.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/145366 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/542 [00:00<?, ? examples/s]

In [7]:
preprocessed_train = []
preprocessed_test = []

def preprocess_dataset(dataset_split):
    processed = []
    for example in dataset_split:
        image = example['image'].convert("RGB")
        image_tensor = clip_processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)

        inputs = tokenizer(example['caption'], return_tensors="pt", padding="max_length", truncation=True, max_length=256)
        latex_code = example['code']
        # Optionally normalize if any weird characters
        if isinstance(latex_code, str):
            latex_code = latex_code.replace("\r\n", "\n").replace("\r", "\n")  # clean CRLF

        targets = tokenizer(
            latex_code,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=256  # increase to avoid cutting code
        )
        """
        print("Raw TikZ Code:", example['code'])
        decoded_tokens = tokenizer.convert_ids_to_tokens(targets['input_ids'].squeeze(0))
        print("Tokenized TikZ (tokens):", decoded_tokens)
        """
        processed.append({
            'image': image_tensor,
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': targets['input_ids'].squeeze(0),
            'caption': example['caption'],
            'reference': example['code']
        })
    return processed

preprocessed_train = preprocess_dataset(train_raw)
preprocessed_test = preprocess_dataset(test_raw)

In [8]:
from torch.utils.data import Dataset

class TikzDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset_torch = TikzDataset(preprocessed_train)

In [9]:
def collate_fn(batch):
    return {
        'image': torch.stack([x['image'] for x in batch]),
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch]),
        'caption': [x['caption'] for x in batch],
        'reference': [x['reference'] for x in batch],
    }

loader = DataLoader(dataset_torch, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TikZGenModel(clip_model, t5_model).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loader = DataLoader(dataset_torch, batch_size=4, shuffle=True, collate_fn=collate_fn)

model.train()
for epoch in range(6):
    loop = tqdm(loader, desc=f"Epoch {epoch+1}")
    for batch in loop:
        image = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(image, input_ids, attention_mask, labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())

    torch.save(model.state_dict(), f"/content/tikzgen_epoch{epoch+1}.pt")

Epoch 1:   0%|          | 0/450 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 1: 100%|██████████| 450/450 [04:12<00:00,  1.78it/s, loss=3.62]
Epoch 2: 100%|██████████| 450/450 [04:19<00:00,  1.74it/s, loss=2.89]
Epoch 3: 100%|██████████| 450/450 [04:18<00:00,  1.74it/s, loss=3.35]
Epoch 4: 100%|██████████| 450/450 [04:20<00:00,  1.73it/s, loss=3.64]


KeyboardInterrupt: 

In [27]:
model = TikZGenModel(clip_model, t5_model).to(device)
model.t5_model.config.decoder_start_token_id = tokenizer.pad_token_id
model.load_state_dict(torch.load("/content/tikzgen_epoch2.pt"))

<All keys matched successfully>

In [61]:
def generate_tikz(caption, image_tensor, device="cuda"):
    model.eval()
    model.to(device)

    start_id = tokenizer.eos_token_id if tokenizer.eos_token_id != 0 else tokenizer.pad_token_id
    model.t5_model.config.decoder_start_token_id = tokenizer.pad_token_id  # which is 0



    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension
    inputs = tokenizer(caption, return_tensors="pt", padding=True, truncation=True, max_length=64)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        text_features = model.t5_model.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        seq_len = text_features.size(1)
        image_features = model.clip_model.get_image_features(pixel_values=image_tensor).unsqueeze(1)
        image_features = image_features.expand(-1, seq_len, -1)
        fused = model.fusion(torch.cat([text_features, image_features], dim=-1))

        output_ids = model.t5_model.generate(
            #inputs_embeds=fused,
            input_ids=input_ids,
            max_length=256,
            num_beams=4,
            early_stopping=True,
            decoder_start_token_id=model.t5_model.config.decoder_start_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [62]:
test_data = preprocessed_test[:20]
#test_data = test_raw.select(range(20))
# Collect predictions and inputs
predictions, references, images, captions = [], [], [], []

for ex in test_data:
    pred = generate_tikz(ex['caption'], ex['image'])
    predictions.append(pred)
    references.append(ex['reference'])  # or ex['reference'] if using preprocessed_data
    images.append(ex['image'])
    captions.append(ex['caption'])

In [70]:
for i in range(5):
    print("\nPrediction:")
    print(predictions[i])
    print("------------------------------")
"""
print("\nCode 1 :")
print(references[9])
"""


Prediction:
function ( plot ) {\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot3\addplot
------------------------------

Prediction:
tikz nodes: centering with small font

I,tikz nodes: centering with small font

I,text

I,tikz nodes: centering with small font

I,tikz nodes: centering
------------------------------

Prediction:
,
------------------------------

Prediction:
box.  function ( )box.box.box.box.box.
------------------------------

Prediction:
{{}{}{  public int size () { return size

'\nprint("\nCode 1 :")\nprint(references[9])\n'

In [18]:
print("Pad token ID:", tokenizer.pad_token_id)
print("Decoder start token ID:", model.t5_model.config.decoder_start_token_id)

Pad token ID: 0
Decoder start token ID: 0


In [73]:
from collections import Counter
from nltk.util import ngrams
import numpy as np

def crystal_bleu(candidate_list, references_list, n=4):
    """
    Simplified CrystalBLEU for evaluating text generation (e.g., TikZ).
    Based on n-gram overlap, with smoothing.
    """
    def count_ngrams(sequence, n):
        return Counter(ngrams(sequence, n)) if len(sequence) >= n else Counter()

    scores = []
    for candidate, references in zip(candidate_list, references_list):
        candidate_tokens = candidate.split()
        reference_tokens = references[0].split()

        precision_scores = []
        for i in range(1, n+1):
            cand_ng = count_ngrams(candidate_tokens, i)
            ref_ng = count_ngrams(reference_tokens, i)

            overlap = sum((cand_ng & ref_ng).values())
            total = max(sum(cand_ng.values()), 1)  # avoid division by zero
            precision_scores.append(overlap / total)

        score = np.exp(np.mean([np.log(p + 1e-8) for p in precision_scores]))  # geometric mean
        scores.append(score)

    return float(np.mean(scores))

In [74]:
cb = crystal_bleu(predictions, [[ref] for ref in references])
print("CrystalBLEU:", cb)

CrystalBLEU: 6.344908721063089e-08
